diff --git a/safesocket/pipe_windows.go b/safesocket/pipe_windows.go index 964d34da1..e8c256adf 100644 --- a/safesocket/pipe_windows.go +++ b/safesocket/pipe_windows.go @@ -15,16 +15,8 @@ func path(vendor, name string, port uint16) string { return fmt.Sprintf("127.0.0.1:%v", port) } -func ConnCloseRead(c net.Conn) error { - return c.(*net.TCPConn).CloseRead() -} - -func ConnCloseWrite(c net.Conn) error { - return c.(*net.TCPConn).CloseWrite() -} - // TODO(apenwarr): handle magic cookie auth -func Connect(path string, port uint16) (net.Conn, error) { +func connect(path string, port uint16) (net.Conn, error) { pipe, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port)) if err != nil { return nil, err @@ -45,7 +37,7 @@ func setFlags(network, address string, c syscall.RawConn) error { // just always using a TCP session on a fixed port on localhost. As a // result, on Windows we ignore the vendor and name strings. // TODO(apenwarr): handle magic cookie auth -func Listen(path string, port uint16) (net.Listener, uint16, error) { +func listen(path string, port uint16) (_ net.Listener, gotPort uint16, _ error) { lc := net.ListenConfig{ Control: setFlags, } diff --git a/safesocket/safesocket.go b/safesocket/safesocket.go new file mode 100644 index 000000000..81f86782b --- /dev/null +++ b/safesocket/safesocket.go @@ -0,0 +1,42 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !windows + +// Package safesocket creates either a Unix socket, if possible, or +// otherwise a localhost TCP connection. +package safesocket + +import ( + "net" +) + +type closeable interface { + CloseRead() error + CloseWrite() error +} + +// ConnCloseRead calls c's CloseRead method. c is expected to be +// either a UnixConn or TCPConn as returned from this package. +func ConnCloseRead(c net.Conn) error { + return c.(closeable).CloseRead() +} + +// ConnCloseWrite calls c's CloseWrite method. c is expected to be +// either a UnixConn or TCPConn as returned from this package. +func ConnCloseWrite(c net.Conn) error { + return c.(closeable).CloseWrite() +} + +// Connect connects to either path (on Unix) or the provided localhost port (on Windows). +func Connect(path string, port uint16) (net.Conn, error) { + return connect(path, port) +} + +// Listen returns a listener either on Unix socket path (on Unix), or +// the localhost port (on Windows). +// If port is 0, the returned gotPort says which port was selected on Windows. +func Listen(path string, port uint16) (_ net.Listener, gotPort uint16, _ error) { + return listen(path, port) +} diff --git a/safesocket/unixsocket.go b/safesocket/unixsocket.go index ac9f51d7f..8e5bd8b92 100644 --- a/safesocket/unixsocket.go +++ b/safesocket/unixsocket.go @@ -12,16 +12,8 @@ import ( "os" ) -func ConnCloseRead(c net.Conn) error { - return c.(*net.UnixConn).CloseRead() -} - -func ConnCloseWrite(c net.Conn) error { - return c.(*net.UnixConn).CloseWrite() -} - // TODO(apenwarr): handle magic cookie auth -func Connect(path string, port uint16) (net.Conn, error) { +func connect(path string, port uint16) (net.Conn, error) { pipe, err := net.Dial("unix", path) if err != nil { return nil, err @@ -30,7 +22,7 @@ func Connect(path string, port uint16) (net.Conn, error) { } // TODO(apenwarr): handle magic cookie auth -func Listen(path string, port uint16) (net.Listener, uint16, error) { +func listen(path string, port uint16) (ln net.Listener, _ uint16, err error) { // Unix sockets hang around in the filesystem even after nobody // is listening on them. (Which is really unfortunate but long- // entrenched semantics.) Try connecting first; if it works, then