@ -13,6 +13,7 @@ import (
"encoding/binary"
"encoding/binary"
"encoding/json"
"encoding/json"
"errors"
"errors"
"flag"
"fmt"
"fmt"
"io"
"io"
"io/ioutil"
"io/ioutil"
@ -116,6 +117,7 @@ type Direct struct {
// hostinfo is mutated in-place while mu is held.
// hostinfo is mutated in-place while mu is held.
hostinfo * tailcfg . Hostinfo // always non-nil
hostinfo * tailcfg . Hostinfo // always non-nil
endpoints [ ] string
endpoints [ ] string
everEndpoints bool // whether we've ever had non-empty endpoints
localPort uint16 // or zero to mean auto
localPort uint16 // or zero to mean auto
}
}
@ -476,6 +478,9 @@ func (c *Direct) newEndpoints(localPort uint16, endpoints []string) (changed boo
c . logf ( "client.newEndpoints(%v, %v)" , localPort , endpoints )
c . logf ( "client.newEndpoints(%v, %v)" , localPort , endpoints )
c . localPort = localPort
c . localPort = localPort
c . endpoints = append ( c . endpoints [ : 0 ] , endpoints ... )
c . endpoints = append ( c . endpoints [ : 0 ] , endpoints ... )
if len ( endpoints ) > 0 {
c . everEndpoints = true
}
return true // changed
return true // changed
}
}
@ -488,6 +493,13 @@ func (c *Direct) SetEndpoints(localPort uint16, endpoints []string) (changed boo
return c . newEndpoints ( localPort , endpoints )
return c . newEndpoints ( localPort , endpoints )
}
}
func inTest ( ) bool { return flag . Lookup ( "test.v" ) != nil }
// PollNetMap makes a /map request to download the network map, calling cb with
// each new netmap.
//
// maxPolls is how many network maps to download; common values are 1
// or -1 (to keep a long-poll query open to the server).
func ( c * Direct ) PollNetMap ( ctx context . Context , maxPolls int , cb func ( * NetworkMap ) ) error {
func ( c * Direct ) PollNetMap ( ctx context . Context , maxPolls int , cb func ( * NetworkMap ) ) error {
c . mu . Lock ( )
c . mu . Lock ( )
persist := c . persist
persist := c . persist
@ -497,6 +509,7 @@ func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkM
backendLogID := hostinfo . BackendLogID
backendLogID := hostinfo . BackendLogID
localPort := c . localPort
localPort := c . localPort
ep := append ( [ ] string ( nil ) , c . endpoints ... )
ep := append ( [ ] string ( nil ) , c . endpoints ... )
everEndpoints := c . everEndpoints
c . mu . Unlock ( )
c . mu . Unlock ( )
if backendLogID == "" {
if backendLogID == "" {
@ -504,7 +517,7 @@ func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkM
}
}
allowStream := maxPolls != 1
allowStream := maxPolls != 1
c . logf ( "PollNetMap: stream=%v :%v %v", maxPolls , localPort , ep )
c . logf ( "PollNetMap: stream=%v :%v ep=%v", allowStream , localPort , ep )
vlogf := logger . Discard
vlogf := logger . Discard
if Debug . NetMap {
if Debug . NetMap {
@ -525,6 +538,17 @@ func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkM
if c . newDecompressor != nil {
if c . newDecompressor != nil {
request . Compress = "zstd"
request . Compress = "zstd"
}
}
// On initial startup before we know our endpoints, set the ReadOnly flag
// to tell the control server not to distribute out our (empty) endpoints to peers.
// Presumably we'll learn our endpoints in a half second and do another post
// with useful results. The first POST just gets us the DERP map which we
// need to do the STUN queries to discover our endpoints.
// TODO(bradfitz): we skip this optimization in tests, though,
// because the e2e tests are currently hyperspecific about the
// ordering of things. The e2e tests need love.
if len ( ep ) == 0 && ! everEndpoints && ! inTest ( ) {
request . ReadOnly = true
}
bodyData , err := encode ( request , & serverKey , & c . machinePrivKey )
bodyData , err := encode ( request , & serverKey , & c . machinePrivKey )
if err != nil {
if err != nil {
@ -532,16 +556,17 @@ func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkM
return err
return err
}
}
ctx , cancel := context . WithCancel ( ctx )
defer cancel ( )
machinePubKey := tailcfg . MachineKey ( c . machinePrivKey . Public ( ) )
machinePubKey := tailcfg . MachineKey ( c . machinePrivKey . Public ( ) )
t0 := time . Now ( )
t0 := time . Now ( )
u := fmt . Sprintf ( "%s/machine/%s/map" , serverURL , machinePubKey . HexString ( ) )
u := fmt . Sprintf ( "%s/machine/%s/map" , serverURL , machinePubKey . HexString ( ) )
req , err := http . NewRequest ( "POST" , u , bytes . NewReader ( bodyData ) )
req , err := http . NewRequestWithContext ( ctx , "POST" , u , bytes . NewReader ( bodyData ) )
if err != nil {
if err != nil {
return err
return err
}
}
ctx , cancel := context . WithCancel ( ctx )
defer cancel ( )
req = req . WithContext ( ctx )
res , err := c . httpc . Do ( req )
res , err := c . httpc . Do ( req )
if err != nil {
if err != nil {