@ -9,7 +9,10 @@ import (
"encoding/base64"
"encoding/base64"
"errors"
"errors"
"fmt"
"fmt"
"io"
"net"
"net/http"
"net/http"
"time"
"nhooyr.io/websocket"
"nhooyr.io/websocket"
"tailscale.com/control/controlbase"
"tailscale.com/control/controlbase"
@ -18,16 +21,20 @@ import (
"tailscale.com/types/key"
"tailscale.com/types/key"
)
)
// AcceptHTTP upgrades the HTTP request given by w and r into a
// AcceptHTTP upgrades the HTTP request given by w and r into a Tailscale
// Tailscale control protocol base transport connection.
// control protocol base transport connection.
//
//
// AcceptHTTP always writes an HTTP response to w. The caller must not
// AcceptHTTP always writes an HTTP response to w. The caller must not attempt
// attempt their own response after calling AcceptHTTP.
// their own response after calling AcceptHTTP.
//
//
// extraHeader optionally specifies extra header(s) to send in the
// earlyWrite optionally specifies a func to write to the noise connection
// 101 Switching Protocols Upgrade response. It must not include the "Upgrade"
// (encrypted). It receives the negotiated version and a writer to write to, if
// or "Connection" headers; they will be replaced.
// desired.
func AcceptHTTP ( ctx context . Context , w http . ResponseWriter , r * http . Request , private key . MachinePrivate , extraHeader http . Header ) ( * controlbase . Conn , error ) {
func AcceptHTTP ( ctx context . Context , w http . ResponseWriter , r * http . Request , private key . MachinePrivate , earlyWrite func ( protocolVersion int , w io . Writer ) error ) ( * controlbase . Conn , error ) {
return acceptHTTP ( ctx , w , r , private , earlyWrite )
}
func acceptHTTP ( ctx context . Context , w http . ResponseWriter , r * http . Request , private key . MachinePrivate , earlyWrite func ( protocolVersion int , w io . Writer ) error ) ( _ * controlbase . Conn , retErr error ) {
next := r . Header . Get ( "Upgrade" )
next := r . Header . Get ( "Upgrade" )
if next == "" {
if next == "" {
http . Error ( w , "missing next protocol" , http . StatusBadRequest )
http . Error ( w , "missing next protocol" , http . StatusBadRequest )
@ -58,9 +65,6 @@ func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri
return nil , errors . New ( "can't hijack client connection" )
return nil , errors . New ( "can't hijack client connection" )
}
}
for k , vv := range extraHeader {
w . Header ( ) [ k ] = vv
}
w . Header ( ) . Set ( "Upgrade" , upgradeHeaderValue )
w . Header ( ) . Set ( "Upgrade" , upgradeHeaderValue )
w . Header ( ) . Set ( "Connection" , "upgrade" )
w . Header ( ) . Set ( "Connection" , "upgrade" )
w . WriteHeader ( http . StatusSwitchingProtocols )
w . WriteHeader ( http . StatusSwitchingProtocols )
@ -69,18 +73,41 @@ func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri
if err != nil {
if err != nil {
return nil , fmt . Errorf ( "hijacking client connection: %w" , err )
return nil , fmt . Errorf ( "hijacking client connection: %w" , err )
}
}
defer func ( ) {
if retErr != nil {
conn . Close ( )
}
} ( )
if err := brw . Flush ( ) ; err != nil {
if err := brw . Flush ( ) ; err != nil {
conn . Close ( )
return nil , fmt . Errorf ( "flushing hijacked HTTP buffer: %w" , err )
return nil , fmt . Errorf ( "flushing hijacked HTTP buffer: %w" , err )
}
}
conn = netutil . NewDrainBufConn ( conn , brw . Reader )
conn = netutil . NewDrainBufConn ( conn , brw . Reader )
nc , err := controlbase . Server ( ctx , conn , private , init )
cwc := newWriteCorkingConn ( conn )
nc , err := controlbase . Server ( ctx , cwc , private , init )
if err != nil {
if err != nil {
conn . Close ( )
return nil , fmt . Errorf ( "noise handshake failed: %w" , err )
return nil , fmt . Errorf ( "noise handshake failed: %w" , err )
}
}
if earlyWrite != nil {
if deadline , ok := ctx . Deadline ( ) ; ok {
if err := conn . SetDeadline ( deadline ) ; err != nil {
return nil , fmt . Errorf ( "setting conn deadline: %w" , err )
}
defer conn . SetDeadline ( time . Time { } )
}
if err := earlyWrite ( nc . ProtocolVersion ( ) , nc ) ; err != nil {
return nil , err
}
}
if err := cwc . uncork ( ) ; err != nil {
return nil , err
}
return nc , nil
return nc , nil
}
}
@ -128,3 +155,61 @@ func acceptWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request
return nc , nil
return nc , nil
}
}
// corkConn is a net.Conn wrapper that initially buffers all writes until uncork
// is called. If the conn is corked and a Read occurs, the Read will flush any
// buffered (corked) write.
//
// Until uncorked, Read/Write/uncork may be not called concurrently.
//
// Deadlines still work, but a corked write ignores deadlines until a Read or
// uncork goes to do that Write.
//
// Use newWriteCorkingConn to create one.
type corkConn struct {
net . Conn
corked bool
buf [ ] byte // corked data
}
func newWriteCorkingConn ( c net . Conn ) * corkConn {
return & corkConn { Conn : c , corked : true }
}
func ( c * corkConn ) Write ( b [ ] byte ) ( int , error ) {
if c . corked {
c . buf = append ( c . buf , b ... )
return len ( b ) , nil
}
return c . Conn . Write ( b )
}
func ( c * corkConn ) Read ( b [ ] byte ) ( int , error ) {
if c . corked {
if err := c . flush ( ) ; err != nil {
return 0 , err
}
}
return c . Conn . Read ( b )
}
// uncork flushes any buffered data and uncorks the connection so future Writes
// don't buffer. It may not be called concurrently with reads or writes and
// may only be called once.
func ( c * corkConn ) uncork ( ) error {
if ! c . corked {
panic ( "usage error; uncork called twice" ) // worth panicking to catch misuse
}
err := c . flush ( )
c . corked = false
return err
}
func ( c * corkConn ) flush ( ) error {
if len ( c . buf ) == 0 {
return nil
}
_ , err := c . Conn . Write ( c . buf )
c . buf = nil
return err
}