From b4be4f089f5b4c7274ccc3a3978ff5b662f541d0 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Thu, 26 Oct 2023 09:14:17 -0700 Subject: [PATCH] safesocket: make clear which net.Conns are winio types Follow-up to earlier #9049. Updates #9049 Change-Id: I121fbd2468770233a23ab5ee3df42698ca1dabc2 Signed-off-by: Brad Fitzpatrick --- safesocket/basic_test.go | 6 +-- safesocket/pipe_windows.go | 46 ++++++++++------- safesocket/pipe_windows_test.go | 88 ++++++++++++++++++++++++++++++++- 3 files changed, 117 insertions(+), 23 deletions(-) diff --git a/safesocket/basic_test.go b/safesocket/basic_test.go index ebb5f2f3a..fca1c3f09 100644 --- a/safesocket/basic_test.go +++ b/safesocket/basic_test.go @@ -25,7 +25,7 @@ func TestBasics(t *testing.T) { t.Cleanup(downgradeSDDL()) } - l, err := Listen(sock) + ln, err := Listen(sock) if err != nil { t.Fatal(err) } @@ -33,12 +33,12 @@ func TestBasics(t *testing.T) { errs := make(chan error, 2) go func() { - s, err := l.Accept() + s, err := ln.Accept() if err != nil { errs <- err return } - l.Close() + ln.Close() s.Write([]byte("hello")) b := make([]byte, 1024) diff --git a/safesocket/pipe_windows.go b/safesocket/pipe_windows.go index 999929120..1259b7834 100644 --- a/safesocket/pipe_windows.go +++ b/safesocket/pipe_windows.go @@ -57,27 +57,26 @@ func listen(path string) (net.Listener, error) { // the Windows access token associated with the connection's client. The // embedded net.Conn must be a go-winio PipeConn. type WindowsClientConn struct { - net.Conn + winioPipeConn token windows.Token } -// winioPipeHandle is fulfilled by the underlying code implementing go-winio's -// PipeConn interface. -type winioPipeHandle interface { +// winioPipeConn is a subset of the interface implemented by the go-winio's +// unexported *win32pipe type, as returned by go-winio's ListenPipe +// net.Listener's Accept method. This type is used in places where we really are +// assuming that specific unexported type and its Fd method. +type winioPipeConn interface { + net.Conn // Fd returns the Windows handle associated with the connection. Fd() uintptr } -func resolvePipeHandle(c net.Conn) windows.Handle { - wph, ok := c.(winioPipeHandle) - if !ok { - return 0 - } - return windows.Handle(wph.Fd()) +func resolvePipeHandle(pc winioPipeConn) windows.Handle { + return windows.Handle(pc.Fd()) } func (conn *WindowsClientConn) handle() windows.Handle { - return resolvePipeHandle(conn.Conn) + return resolvePipeHandle(conn.winioPipeConn) } // ClientPID returns the pid of conn's client, or else an error. @@ -99,11 +98,14 @@ func (conn *WindowsClientConn) Close() error { conn.token.Close() conn.token = 0 } - return conn.Conn.Close() + return conn.winioPipeConn.Close() } +// winIOPipeListener is a net.Listener that wraps a go-winio PipeListener and +// returns net.Conn values of type *WindowsClientConn with the associated +// windows.Token. type winIOPipeListener struct { - net.Listener + net.Listener // must be from winio.ListenPipe } func (lw *winIOPipeListener) Accept() (net.Conn, error) { @@ -112,22 +114,28 @@ func (lw *winIOPipeListener) Accept() (net.Conn, error) { return nil, err } - token, err := clientUserAccessToken(conn) + pipeConn, ok := conn.(winioPipeConn) + if !ok { + conn.Close() + return nil, fmt.Errorf("unexpected type %T from winio.ListenPipe listener (itself a %T)", conn, lw.Listener) + } + + token, err := clientUserAccessToken(pipeConn) if err != nil { conn.Close() return nil, err } return &WindowsClientConn{ - Conn: conn, - token: token, + winioPipeConn: pipeConn, + token: token, }, nil } -func clientUserAccessToken(c net.Conn) (windows.Token, error) { - h := resolvePipeHandle(c) +func clientUserAccessToken(pc winioPipeConn) (windows.Token, error) { + h := resolvePipeHandle(pc) if h == 0 { - return 0, fmt.Errorf("not a windows handle: %T", c) + return 0, fmt.Errorf("clientUserAccessToken failed to get handle from pipeConn %T", pc) } // Impersonation touches thread-local state, so we need to lock until the diff --git a/safesocket/pipe_windows_test.go b/safesocket/pipe_windows_test.go index 6028adbe8..c6f8eb393 100644 --- a/safesocket/pipe_windows_test.go +++ b/safesocket/pipe_windows_test.go @@ -3,7 +3,12 @@ package safesocket -import "tailscale.com/util/winutil" +import ( + "fmt" + "testing" + + "tailscale.com/util/winutil" +) func init() { // downgradeSDDL is a test helper that downgrades the windowsSDDL variable if @@ -20,3 +25,84 @@ func init() { return func() {} } } + +// TestExpectedWindowsTypes is a copy of TestBasics specialized for Windows with +// type assertions about the types of listeners and conns we expect. +func TestExpectedWindowsTypes(t *testing.T) { + t.Cleanup(downgradeSDDL()) + const sock = `\\.\pipe\tailscale-test` + ln, err := Listen(sock) + if err != nil { + t.Fatal(err) + } + if got, want := fmt.Sprintf("%T", ln), "*safesocket.winIOPipeListener"; got != want { + t.Errorf("got listener type %q; want %q", got, want) + } + + errs := make(chan error, 2) + + go func() { + s, err := ln.Accept() + if err != nil { + errs <- err + return + } + ln.Close() + + wcc, ok := s.(*WindowsClientConn) + if !ok { + s.Close() + errs <- fmt.Errorf("accepted type %T; want WindowsClientConn", s) + return + } + if wcc.winioPipeConn.Fd() == 0 { + t.Error("accepted conn had unexpected zero fd") + } + if wcc.token == 0 { + t.Error("accepted conn had unexpected zero token") + } + + s.Write([]byte("hello")) + + b := make([]byte, 1024) + n, err := s.Read(b) + if err != nil { + errs <- err + return + } + t.Logf("server read %d bytes.", n) + if string(b[:n]) != "world" { + errs <- fmt.Errorf("got %#v, expected %#v\n", string(b[:n]), "world") + return + } + s.Close() + errs <- nil + }() + + go func() { + s := DefaultConnectionStrategy(sock) + c, err := Connect(s) + if err != nil { + errs <- err + return + } + c.Write([]byte("world")) + b := make([]byte, 1024) + n, err := c.Read(b) + if err != nil { + errs <- err + return + } + if string(b[:n]) != "hello" { + errs <- fmt.Errorf("got %#v, expected %#v\n", string(b[:n]), "hello") + } + c.Close() + errs <- nil + }() + + for i := 0; i < 2; i++ { + if err := <-errs; err != nil { + t.Fatal(err) + } + } +}