@ -7,7 +7,6 @@ package controlclient
import (
import (
"bytes"
"bytes"
"context"
"context"
"crypto/rand"
"encoding/binary"
"encoding/binary"
"encoding/json"
"encoding/json"
"errors"
"errors"
@ -28,7 +27,7 @@ import (
"sync/atomic"
"sync/atomic"
"time"
"time"
"go lang.org/x/crypto/nacl/box "
"go 4.org/mem "
"inet.af/netaddr"
"inet.af/netaddr"
"tailscale.com/control/controlknobs"
"tailscale.com/control/controlknobs"
"tailscale.com/health"
"tailscale.com/health"
@ -42,6 +41,7 @@ import (
"tailscale.com/net/tlsdial"
"tailscale.com/net/tlsdial"
"tailscale.com/net/tshttpproxy"
"tailscale.com/net/tshttpproxy"
"tailscale.com/tailcfg"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/logger"
"tailscale.com/types/logger"
"tailscale.com/types/netmap"
"tailscale.com/types/netmap"
"tailscale.com/types/opt"
"tailscale.com/types/opt"
@ -62,14 +62,14 @@ type Direct struct {
logf logger . Logf
logf logger . Logf
linkMon * monitor . Mon // or nil
linkMon * monitor . Mon // or nil
discoPubKey tailcfg . DiscoKey
discoPubKey tailcfg . DiscoKey
getMachinePrivKey func ( ) ( wg key. Private, error )
getMachinePrivKey func ( ) ( key. Machine Private, error )
debugFlags [ ] string
debugFlags [ ] string
keepSharerAndUserSplit bool
keepSharerAndUserSplit bool
skipIPForwardingCheck bool
skipIPForwardingCheck bool
pinger Pinger
pinger Pinger
mu sync . Mutex // mutex guards the following fields
mu sync . Mutex // mutex guards the following fields
serverKey wgkey. Key
serverKey key. MachinePublic
persist persist . Persist
persist persist . Persist
authKey string
authKey string
tryingNewKey wgkey . Private
tryingNewKey wgkey . Private
@ -83,12 +83,12 @@ type Direct struct {
}
}
type Options struct {
type Options struct {
Persist persist . Persist // initial persistent data
Persist persist . Persist // initial persistent data
GetMachinePrivateKey func ( ) ( wg key. Private, error ) // returns the machine key to use
GetMachinePrivateKey func ( ) ( key. Machine Private, error ) // returns the machine key to use
ServerURL string // URL of the tailcontrol server
ServerURL string // URL of the tailcontrol server
AuthKey string // optional node auth key for auto registration
AuthKey string // optional node auth key for auto registration
TimeNow func ( ) time . Time // time.Now implementation used by Client
TimeNow func ( ) time . Time // time.Now implementation used by Client
Hostinfo * tailcfg . Hostinfo // non-nil passes ownership, nil means to use default using os.Hostname, etc
Hostinfo * tailcfg . Hostinfo // non-nil passes ownership, nil means to use default using os.Hostname, etc
DiscoPublicKey tailcfg . DiscoKey
DiscoPublicKey tailcfg . DiscoKey
NewDecompressor func ( ) ( Decompressor , error )
NewDecompressor func ( ) ( Decompressor , error )
KeepAlive bool
KeepAlive bool
@ -320,7 +320,7 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new
if err != nil {
if err != nil {
return regen , opt . URL , err
return regen , opt . URL , err
}
}
c . logf ( "control server key %s from %s" , serverKey . Hex String( ) , c . serverURL )
c . logf ( "control server key %s from %s" , serverKey . Short String( ) , c . serverURL )
c . mu . Lock ( )
c . mu . Lock ( )
c . serverKey = serverKey
c . serverKey = serverKey
@ -398,13 +398,13 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new
c . logf ( "RegisterRequest: %s" , j )
c . logf ( "RegisterRequest: %s" , j )
}
}
bodyData , err := encode ( request , & serverKey , & machinePrivKey )
bodyData , err := encode ( request , serverKey , machinePrivKey )
if err != nil {
if err != nil {
return regen , opt . URL , err
return regen , opt . URL , err
}
}
body := bytes . NewReader ( bodyData )
body := bytes . NewReader ( bodyData )
u := fmt . Sprintf ( "%s/machine/%s" , c . serverURL , machinePrivKey . Public ( ) . HexString( ) )
u := fmt . Sprintf ( "%s/machine/%s" , c . serverURL , machinePrivKey . Public ( ) . Untyped HexString( ) )
req , err := http . NewRequest ( "POST" , u , body )
req , err := http . NewRequest ( "POST" , u , body )
if err != nil {
if err != nil {
return regen , opt . URL , err
return regen , opt . URL , err
@ -422,7 +422,7 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new
res . StatusCode , strings . TrimSpace ( string ( msg ) ) )
res . StatusCode , strings . TrimSpace ( string ( msg ) ) )
}
}
resp := tailcfg . RegisterResponse { }
resp := tailcfg . RegisterResponse { }
if err := decode ( res , & resp , & serverKey , & machinePrivKey ) ; err != nil {
if err := decode ( res , & resp , serverKey , machinePrivKey ) ; err != nil {
c . logf ( "error decoding RegisterResponse with server key %s and machine key %s: %v" , serverKey , machinePrivKey . Public ( ) , err )
c . logf ( "error decoding RegisterResponse with server key %s and machine key %s: %v" , serverKey , machinePrivKey . Public ( ) , err )
return regen , opt . URL , fmt . Errorf ( "register request: %v" , err )
return regen , opt . URL , fmt . Errorf ( "register request: %v" , err )
}
}
@ -636,7 +636,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm
request . ReadOnly = true
request . ReadOnly = true
}
}
bodyData , err := encode ( request , & serverKey , & machinePrivKey )
bodyData , err := encode ( request , serverKey , machinePrivKey )
if err != nil {
if err != nil {
vlogf ( "netmap: encode: %v" , err )
vlogf ( "netmap: encode: %v" , err )
return err
return err
@ -645,9 +645,9 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm
ctx , cancel := context . WithCancel ( ctx )
ctx , cancel := context . WithCancel ( ctx )
defer cancel ( )
defer cancel ( )
machinePubKey := tailcfg. MachineKey ( machinePrivKey. Public ( ) )
machinePubKey := machinePrivKey. Public ( )
t0 := time . Now ( )
t0 := time . Now ( )
u := fmt . Sprintf ( "%s/machine/%s/map" , serverURL , machinePubKey . HexString( ) )
u := fmt . Sprintf ( "%s/machine/%s/map" , serverURL , machinePubKey . Untyped HexString( ) )
req , err := http . NewRequestWithContext ( ctx , "POST" , u , bytes . NewReader ( bodyData ) )
req , err := http . NewRequestWithContext ( ctx , "POST" , u , bytes . NewReader ( bodyData ) )
if err != nil {
if err != nil {
@ -734,7 +734,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm
vlogf ( "netmap: read body after %v" , time . Since ( t0 ) . Round ( time . Millisecond ) )
vlogf ( "netmap: read body after %v" , time . Since ( t0 ) . Round ( time . Millisecond ) )
var resp tailcfg . MapResponse
var resp tailcfg . MapResponse
if err := c . decodeMsg ( msg , & resp , & machinePrivKey ) ; err != nil {
if err := c . decodeMsg ( msg , & resp , machinePrivKey ) ; err != nil {
vlogf ( "netmap: decode error: %v" )
vlogf ( "netmap: decode error: %v" )
return err
return err
}
}
@ -830,7 +830,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm
return nil
return nil
}
}
func decode ( res * http . Response , v interface { } , serverKey * wgkey . Key , mkey * wgkey . Private) error {
func decode ( res * http . Response , v interface { } , serverKey key . MachinePublic , mkey key . Machine Private) error {
defer res . Body . Close ( )
defer res . Body . Close ( )
msg , err := ioutil . ReadAll ( io . LimitReader ( res . Body , 1 << 20 ) )
msg , err := ioutil . ReadAll ( io . LimitReader ( res . Body , 1 << 20 ) )
if err != nil {
if err != nil {
@ -849,14 +849,14 @@ var (
var jsonEscapedZero = [ ] byte ( ` \u0000 ` )
var jsonEscapedZero = [ ] byte ( ` \u0000 ` )
func ( c * Direct ) decodeMsg ( msg [ ] byte , v interface { } , machinePrivKey * wg key. Private) error {
func ( c * Direct ) decodeMsg ( msg [ ] byte , v interface { } , machinePrivKey key. Machine Private) error {
c . mu . Lock ( )
c . mu . Lock ( )
serverKey := c . serverKey
serverKey := c . serverKey
c . mu . Unlock ( )
c . mu . Unlock ( )
decrypted , err := decryptMsg ( msg , & serverKey , machinePrivKey )
decrypted , ok := machinePrivKey . OpenFrom ( serverKey , msg )
if err != nil {
if ! ok {
return err
return err ors. New ( "cannot decrypt response" )
}
}
var b [ ] byte
var b [ ] byte
if c . newDecompressor == nil {
if c . newDecompressor == nil {
@ -888,10 +888,10 @@ func (c *Direct) decodeMsg(msg []byte, v interface{}, machinePrivKey *wgkey.Priv
}
}
func decodeMsg ( msg [ ] byte , v interface { } , serverKey * wgkey . Key , machinePrivKey * wgkey . Private) error {
func decodeMsg ( msg [ ] byte , v interface { } , serverKey key . MachinePublic , machinePrivKey key . Machine Private) error {
decrypted , err := decryptMsg ( msg , serverKey , machinePrivKey )
decrypted , ok := machinePrivKey . OpenFrom ( serverKey , msg )
if err != nil {
if ! ok {
return err
return err ors. New ( "cannot decrypt response" )
}
}
if bytes . Contains ( decrypted , jsonEscapedZero ) {
if bytes . Contains ( decrypted , jsonEscapedZero ) {
log . Printf ( "[unexpected] zero byte in controlclient decodeMsg into %T: %q" , v , decrypted )
log . Printf ( "[unexpected] zero byte in controlclient decodeMsg into %T: %q" , v , decrypted )
@ -902,23 +902,7 @@ func decodeMsg(msg []byte, v interface{}, serverKey *wgkey.Key, machinePrivKey *
return nil
return nil
}
}
func decryptMsg ( msg [ ] byte , serverKey * wgkey . Key , mkey * wgkey . Private ) ( [ ] byte , error ) {
func encode ( v interface { } , serverKey key . MachinePublic , mkey key . MachinePrivate ) ( [ ] byte , error ) {
var nonce [ 24 ] byte
if len ( msg ) < len ( nonce ) + 1 {
return nil , fmt . Errorf ( "response missing nonce, len=%d" , len ( msg ) )
}
copy ( nonce [ : ] , msg )
msg = msg [ len ( nonce ) : ]
pub , pri := ( * [ 32 ] byte ) ( serverKey ) , ( * [ 32 ] byte ) ( mkey )
decrypted , ok := box . Open ( nil , msg , & nonce , pub , pri )
if ! ok {
return nil , fmt . Errorf ( "cannot decrypt response (len %d + nonce %d = %d)" , len ( msg ) , len ( nonce ) , len ( msg ) + len ( nonce ) )
}
return decrypted , nil
}
func encode ( v interface { } , serverKey * wgkey . Key , mkey * wgkey . Private ) ( [ ] byte , error ) {
b , err := json . Marshal ( v )
b , err := json . Marshal ( v )
if err != nil {
if err != nil {
return nil , err
return nil , err
@ -928,38 +912,32 @@ func encode(v interface{}, serverKey *wgkey.Key, mkey *wgkey.Private) ([]byte, e
log . Printf ( "MapRequest: %s" , b )
log . Printf ( "MapRequest: %s" , b )
}
}
}
}
var nonce [ 24 ] byte
return mkey . SealTo ( serverKey , b ) , nil
if _ , err := io . ReadFull ( rand . Reader , nonce [ : ] ) ; err != nil {
panic ( err )
}
pub , pri := ( * [ 32 ] byte ) ( serverKey ) , ( * [ 32 ] byte ) ( mkey )
msg := box . Seal ( nonce [ : ] , b , & nonce , pub , pri )
return msg , nil
}
}
func loadServerKey ( ctx context . Context , httpc * http . Client , serverURL string ) ( wgkey. Key , error ) {
func loadServerKey ( ctx context . Context , httpc * http . Client , serverURL string ) ( key . MachinePublic , error ) {
req , err := http . NewRequest ( "GET" , serverURL + "/key" , nil )
req , err := http . NewRequest ( "GET" , serverURL + "/key" , nil )
if err != nil {
if err != nil {
return wgkey. Key { } , fmt . Errorf ( "create control key request: %v" , err )
return key . MachinePublic { } , fmt . Errorf ( "create control key request: %v" , err )
}
}
req = req . WithContext ( ctx )
req = req . WithContext ( ctx )
res , err := httpc . Do ( req )
res , err := httpc . Do ( req )
if err != nil {
if err != nil {
return wgkey. Key { } , fmt . Errorf ( "fetch control key: %v" , err )
return key. MachinePublic { } , fmt . Errorf ( "fetch control key: %v" , err )
}
}
defer res . Body . Close ( )
defer res . Body . Close ( )
b , err := ioutil . ReadAll ( io . LimitReader ( res . Body , 1 << 16 ) )
b , err := ioutil . ReadAll ( io . LimitReader ( res . Body , 1 << 16 ) )
if err != nil {
if err != nil {
return wgkey. Key { } , fmt . Errorf ( "fetch control key response: %v" , err )
return key. MachinePublic { } , fmt . Errorf ( "fetch control key response: %v" , err )
}
}
if res . StatusCode != 200 {
if res . StatusCode != 200 {
return wgkey. Key { } , fmt . Errorf ( "fetch control key: %d: %s" , res . StatusCode , string ( b ) )
return key. MachinePublic { } , fmt . Errorf ( "fetch control key: %d: %s" , res . StatusCode , string ( b ) )
}
}
k ey, err := wgkey . ParseHex ( string ( b ) )
k , err := key . ParseMachinePublicUntyped ( mem . B ( b ) )
if err != nil {
if err != nil {
return wgkey. Key { } , fmt . Errorf ( "fetch control key: %v" , err )
return key. MachinePublic { } , fmt . Errorf ( "fetch control key: %v" , err )
}
}
return k ey , nil
return k , nil
}
}
// Debug contains temporary internal-only debug knobs.
// Debug contains temporary internal-only debug knobs.
@ -1207,13 +1185,13 @@ func (c *Direct) SetDNS(ctx context.Context, req *tailcfg.SetDNSRequest) error {
return errors . New ( "getMachinePrivKey returned zero key" )
return errors . New ( "getMachinePrivKey returned zero key" )
}
}
bodyData , err := encode ( req , & serverKey , & machinePrivKey )
bodyData , err := encode ( req , serverKey , machinePrivKey )
if err != nil {
if err != nil {
return err
return err
}
}
body := bytes . NewReader ( bodyData )
body := bytes . NewReader ( bodyData )
u := fmt . Sprintf ( "%s/machine/%s/set-dns" , c . serverURL , machinePrivKey . Public ( ) . HexString( ) )
u := fmt . Sprintf ( "%s/machine/%s/set-dns" , c . serverURL , machinePrivKey . Public ( ) . Untyped HexString( ) )
hreq , err := http . NewRequestWithContext ( ctx , "POST" , u , body )
hreq , err := http . NewRequestWithContext ( ctx , "POST" , u , body )
if err != nil {
if err != nil {
return err
return err
@ -1228,7 +1206,7 @@ func (c *Direct) SetDNS(ctx context.Context, req *tailcfg.SetDNSRequest) error {
return fmt . Errorf ( "set-dns response: %v, %.200s" , res . Status , strings . TrimSpace ( string ( msg ) ) )
return fmt . Errorf ( "set-dns response: %v, %.200s" , res . Status , strings . TrimSpace ( string ( msg ) ) )
}
}
var setDNSRes struct { } // no fields yet
var setDNSRes struct { } // no fields yet
if err := decode ( res , & setDNSRes , & serverKey , & machinePrivKey ) ; err != nil {
if err := decode ( res , & setDNSRes , serverKey , machinePrivKey ) ; err != nil {
c . logf ( "error decoding SetDNSResponse with server key %s and machine key %s: %v" , serverKey , machinePrivKey . Public ( ) , err )
c . logf ( "error decoding SetDNSResponse with server key %s and machine key %s: %v" , serverKey , machinePrivKey . Public ( ) , err )
return fmt . Errorf ( "set-dns-response: %v" , err )
return fmt . Errorf ( "set-dns-response: %v" , err )
}
}