diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index cf6fd9987..006f2614a 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -61,6 +61,7 @@ import ( type Direct struct { httpc *http.Client // HTTP client used to talk to tailcontrol dialer *tsdial.Dialer + dnsCache *dnscache.Resolver serverURL string // URL of the tailcontrol server timeNow func() time.Time lastPrintMap time.Time @@ -199,6 +200,14 @@ func NewDirect(opts Options) (*Direct, error) { opts.Logf = log.Printf } + dnsCache := &dnscache.Resolver{ + Forward: dnscache.Get().Forward, // use default cache's forwarder + UseLastGood: true, + LookupIPFallback: dnsfallback.MakeLookupFunc(opts.Logf, opts.NetMon), + Logf: opts.Logf, + NetMon: opts.NetMon, + } + httpc := opts.HTTPTestClient if httpc == nil && runtime.GOOS == "js" { // In js/wasm, net/http.Transport (as of Go 1.18) will @@ -208,13 +217,6 @@ func NewDirect(opts Options) (*Direct, error) { httpc = http.DefaultClient } if httpc == nil { - dnsCache := &dnscache.Resolver{ - Forward: dnscache.Get().Forward, // use default cache's forwarder - UseLastGood: true, - LookupIPFallback: dnsfallback.MakeLookupFunc(opts.Logf, opts.NetMon), - Logf: opts.Logf, - NetMon: opts.NetMon, - } tr := http.DefaultTransport.(*http.Transport).Clone() tr.Proxy = tshttpproxy.ProxyFromEnvironment tshttpproxy.SetTransportGetProxyConnectHeader(tr) @@ -250,6 +252,7 @@ func NewDirect(opts Options) (*Direct, error) { onControlTime: opts.OnControlTime, c2nHandler: opts.C2NHandler, dialer: opts.Dialer, + dnsCache: dnsCache, dialPlan: opts.DialPlan, } if opts.Hostinfo == nil { @@ -1509,7 +1512,16 @@ func (c *Direct) getNoiseClient() (*NoiseClient, error) { return nil, err } c.logf("creating new noise client") - nc, err := NewNoiseClient(k, serverNoiseKey, c.serverURL, c.dialer, c.logf, c.netMon, dp) + nc, err := NewNoiseClient(NoiseOpts{ + PrivKey: k, + ServerPubKey: serverNoiseKey, + ServerURL: c.serverURL, + Dialer: c.dialer, + DNSCache: c.dnsCache, + Logf: c.logf, + NetMon: c.netMon, + DialPlan: dp, + }) if err != nil { return nil, err } diff --git a/control/controlclient/noise.go b/control/controlclient/noise.go index 61c472a35..cad81b82c 100644 --- a/control/controlclient/noise.go +++ b/control/controlclient/noise.go @@ -19,6 +19,7 @@ import ( "golang.org/x/net/http2" "tailscale.com/control/controlbase" "tailscale.com/control/controlhttp" + "tailscale.com/net/dnscache" "tailscale.com/net/netmon" "tailscale.com/net/tsdial" "tailscale.com/tailcfg" @@ -158,6 +159,7 @@ type NoiseClient struct { sfDial singleflight.Group[struct{}, *noiseConn] dialer *tsdial.Dialer + dnsCache *dnscache.Resolver privKey key.MachinePrivate serverPubKey key.MachinePublic host string // the host part of serverURL @@ -179,13 +181,39 @@ type NoiseClient struct { connPool map[int]*noiseConn // active connections not yet closed; see noiseConn.Close } +// NoiseOpts contains options for the NewNoiseClient function. All fields are +// required unless otherwise specified. +type NoiseOpts struct { + // PrivKey is this node's private key. + PrivKey key.MachinePrivate + // ServerPubKey is the public key of the server. + ServerPubKey key.MachinePublic + // ServerURL is the URL of the server to connect to. + ServerURL string + // Dialer's SystemDial function is used to connect to the server. + Dialer *tsdial.Dialer + // DNSCache is the caching Resolver to use to connect to the server. + // + // This field can be nil. + DNSCache *dnscache.Resolver + // Logf is the log function to use. This field can be nil. + Logf logger.Logf + // NetMon is the network monitor that, if set, will be used to get the + // network interface state. This field can be nil; if so, the current + // state will be looked up dynamically. + NetMon *netmon.Monitor + // DialPlan, if set, is a function that should return an explicit plan + // on how to connect to the server. + DialPlan func() *tailcfg.ControlDialPlan +} + // NewNoiseClient returns a new noiseClient for the provided server and machine key. // serverURL is of the form https://: (no trailing slash). // // netMon may be nil, if non-nil it's used to do faster interface lookups. // dialPlan may be nil -func NewNoiseClient(privKey key.MachinePrivate, serverPubKey key.MachinePublic, serverURL string, dialer *tsdial.Dialer, logf logger.Logf, netMon *netmon.Monitor, dialPlan func() *tailcfg.ControlDialPlan) (*NoiseClient, error) { - u, err := url.Parse(serverURL) +func NewNoiseClient(opts NoiseOpts) (*NoiseClient, error) { + u, err := url.Parse(opts.ServerURL) if err != nil { return nil, err } @@ -205,16 +233,18 @@ func NewNoiseClient(privKey key.MachinePrivate, serverPubKey key.MachinePublic, httpPort = "80" httpsPort = "443" } + np := &NoiseClient{ - serverPubKey: serverPubKey, - privKey: privKey, + serverPubKey: opts.ServerPubKey, + privKey: opts.PrivKey, host: u.Hostname(), httpPort: httpPort, httpsPort: httpsPort, - dialer: dialer, - dialPlan: dialPlan, - logf: logf, - netMon: netMon, + dialer: opts.Dialer, + dnsCache: opts.DNSCache, + dialPlan: opts.DialPlan, + logf: opts.Logf, + netMon: opts.NetMon, } // Create the HTTP/2 Transport using a net/http.Transport @@ -373,6 +403,7 @@ func (nc *NoiseClient) dial() (*noiseConn, error) { ControlKey: nc.serverPubKey, ProtocolVersion: uint16(tailcfg.CurrentCapabilityVersion), Dialer: nc.dialer.SystemDial, + DNSCache: nc.dnsCache, DialPlan: dialPlan, Logf: nc.logf, NetMon: nc.netMon, diff --git a/control/controlclient/noise_test.go b/control/controlclient/noise_test.go index 11e35f0af..9961e3318 100644 --- a/control/controlclient/noise_test.go +++ b/control/controlclient/noise_test.go @@ -74,7 +74,12 @@ func (tt noiseClientTest) run(t *testing.T) { defer hs.Close() dialer := new(tsdial.Dialer) - nc, err := NewNoiseClient(clientPrivate, serverPrivate.Public(), hs.URL, dialer, nil, nil, nil) + nc, err := NewNoiseClient(NoiseOpts{ + PrivKey: clientPrivate, + ServerPubKey: serverPrivate.Public(), + ServerURL: hs.URL, + Dialer: dialer, + }) if err != nil { t.Fatal(err) } diff --git a/control/controlhttp/client.go b/control/controlhttp/client.go index d04aac518..b0d91bada 100644 --- a/control/controlhttp/client.go +++ b/control/controlhttp/client.go @@ -374,6 +374,22 @@ func (a *Dialer) dialURL(ctx context.Context, u *url.URL, addr netip.Addr) (*Cli }, nil } +// resolver returns a.DNSCache if non-nil or a new *dnscache.Resolver +// otherwise. +func (a *Dialer) resolver() *dnscache.Resolver { + if a.DNSCache != nil { + return a.DNSCache + } + + return &dnscache.Resolver{ + Forward: dnscache.Get().Forward, + LookupIPFallback: dnsfallback.MakeLookupFunc(a.logf, a.NetMon), + UseLastGood: true, + Logf: a.Logf, // not a.logf method; we want to propagate nil-ness + NetMon: a.NetMon, + } +} + // tryURLUpgrade connects to u, and tries to upgrade it to a net.Conn. If addr // is valid, then no DNS is used and the connection will be made to the // provided address. @@ -392,13 +408,7 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr, NetMon: a.NetMon, } } else { - dns = &dnscache.Resolver{ - Forward: dnscache.Get().Forward, - LookupIPFallback: dnsfallback.MakeLookupFunc(a.logf, a.NetMon), - UseLastGood: true, - Logf: a.Logf, // not a.logf method; we want to propagate nil-ness - NetMon: a.NetMon, - } + dns = a.resolver() } var dialer dnscache.DialContextFunc diff --git a/control/controlhttp/constants.go b/control/controlhttp/constants.go index a58ee5374..b838f84c4 100644 --- a/control/controlhttp/constants.go +++ b/control/controlhttp/constants.go @@ -67,6 +67,11 @@ type Dialer struct { // If not specified, this defaults to net.Dialer.DialContext. Dialer dnscache.DialContextFunc + // DNSCache is the caching Resolver used by this Dialer. + // + // If not specified, a new Resolver is created per attempt. + DNSCache *dnscache.Resolver + // Logf, if set, is a logging function to use; if unset, logs are // dropped. Logf logger.Logf