diff --git a/control/controlknobs/controlknobs.go b/control/controlknobs/controlknobs.go index 3ea0575a5..e64bc8011 100644 --- a/control/controlknobs/controlknobs.go +++ b/control/controlknobs/controlknobs.go @@ -48,6 +48,10 @@ type Knobs struct { // PeerMTUEnable is whether the node should do peer path MTU discovery. PeerMTUEnable atomic.Bool + + // DisableDNSForwarderTCPRetries is whether the DNS forwarder should + // skip retrying truncated queries over TCP. + DisableDNSForwarderTCPRetries atomic.Bool } // UpdateFromNodeAttributes updates k (if non-nil) based on the provided self @@ -61,14 +65,15 @@ func (k *Knobs) UpdateFromNodeAttributes(selfNodeAttrs []tailcfg.NodeCapability, return ok || slices.Contains(selfNodeAttrs, attr) } var ( - keepFullWG = has(tailcfg.NodeAttrDebugDisableWGTrim) - disableDRPO = has(tailcfg.NodeAttrDebugDisableDRPO) - disableUPnP = has(tailcfg.NodeAttrDisableUPnP) - randomizeClientPort = has(tailcfg.NodeAttrRandomizeClientPort) - disableDeltaUpdates = has(tailcfg.NodeAttrDisableDeltaUpdates) - oneCGNAT opt.Bool - forceBackgroundSTUN = has(tailcfg.NodeAttrDebugForceBackgroundSTUN) - peerMTUEnable = has(tailcfg.NodeAttrPeerMTUEnable) + keepFullWG = has(tailcfg.NodeAttrDebugDisableWGTrim) + disableDRPO = has(tailcfg.NodeAttrDebugDisableDRPO) + disableUPnP = has(tailcfg.NodeAttrDisableUPnP) + randomizeClientPort = has(tailcfg.NodeAttrRandomizeClientPort) + disableDeltaUpdates = has(tailcfg.NodeAttrDisableDeltaUpdates) + oneCGNAT opt.Bool + forceBackgroundSTUN = has(tailcfg.NodeAttrDebugForceBackgroundSTUN) + peerMTUEnable = has(tailcfg.NodeAttrPeerMTUEnable) + dnsForwarderDisableTCPRetries = has(tailcfg.NodeAttrDNSForwarderDisableTCPRetries) ) if has(tailcfg.NodeAttrOneCGNATEnable) { @@ -85,6 +90,7 @@ func (k *Knobs) UpdateFromNodeAttributes(selfNodeAttrs []tailcfg.NodeCapability, k.ForceBackgroundSTUN.Store(forceBackgroundSTUN) k.DisableDeltaUpdates.Store(disableDeltaUpdates) k.PeerMTUEnable.Store(peerMTUEnable) + k.DisableDNSForwarderTCPRetries.Store(dnsForwarderDisableTCPRetries) } // AsDebugJSON returns k as something that can be marshalled with json.Marshal @@ -94,13 +100,14 @@ func (k *Knobs) AsDebugJSON() map[string]any { return nil } return map[string]any{ - "DisableUPnP": k.DisableUPnP.Load(), - "DisableDRPO": k.DisableDRPO.Load(), - "KeepFullWGConfig": k.KeepFullWGConfig.Load(), - "RandomizeClientPort": k.RandomizeClientPort.Load(), - "OneCGNAT": k.OneCGNAT.Load(), - "ForceBackgroundSTUN": k.ForceBackgroundSTUN.Load(), - "DisableDeltaUpdates": k.DisableDeltaUpdates.Load(), - "PeerMTUEnable": k.PeerMTUEnable.Load(), + "DisableUPnP": k.DisableUPnP.Load(), + "DisableDRPO": k.DisableDRPO.Load(), + "KeepFullWGConfig": k.KeepFullWGConfig.Load(), + "RandomizeClientPort": k.RandomizeClientPort.Load(), + "OneCGNAT": k.OneCGNAT.Load(), + "ForceBackgroundSTUN": k.ForceBackgroundSTUN.Load(), + "DisableDeltaUpdates": k.DisableDeltaUpdates.Load(), + "PeerMTUEnable": k.PeerMTUEnable.Load(), + "DisableDNSForwarderTCPRetries": k.DisableDNSForwarderTCPRetries.Load(), } } diff --git a/net/dns/manager.go b/net/dns/manager.go index cee7d7ede..1b5903b59 100644 --- a/net/dns/manager.go +++ b/net/dns/manager.go @@ -17,6 +17,7 @@ import ( "sync/atomic" "time" + "tailscale.com/control/controlknobs" "tailscale.com/health" "tailscale.com/net/dns/resolver" "tailscale.com/net/netmon" @@ -66,14 +67,14 @@ type Manager struct { // NewManagers created a new manager from the given config. // The netMon parameter is optional; if non-nil it's used to do faster interface lookups. -func NewManager(logf logger.Logf, oscfg OSConfigurator, netMon *netmon.Monitor, dialer *tsdial.Dialer, linkSel resolver.ForwardLinkSelector) *Manager { +func NewManager(logf logger.Logf, oscfg OSConfigurator, netMon *netmon.Monitor, dialer *tsdial.Dialer, linkSel resolver.ForwardLinkSelector, knobs *controlknobs.Knobs) *Manager { if dialer == nil { panic("nil Dialer") } logf = logger.WithPrefix(logf, "dns: ") m := &Manager{ logf: logf, - resolver: resolver.New(logf, netMon, linkSel, dialer), + resolver: resolver.New(logf, netMon, linkSel, dialer, knobs), os: oscfg, } m.ctx, m.ctxCancel = context.WithCancel(context.Background()) @@ -295,7 +296,10 @@ func toIPsOnly(resolvers []*dnstype.Resolver) (ret []netip.Addr) { // Query executes a DNS query received from the given address. The query is // provided in bs as a wire-encoded DNS query without any transport header. // This method is called for requests arriving over UDP and TCP. -func (m *Manager) Query(ctx context.Context, bs []byte, from netip.AddrPort) ([]byte, error) { +// +// The "family" parameter should indicate what type of DNS query this is: +// either "tcp" or "udp". +func (m *Manager) Query(ctx context.Context, bs []byte, family string, from netip.AddrPort) ([]byte, error) { select { case <-m.ctx.Done(): return nil, net.ErrClosed @@ -309,7 +313,7 @@ func (m *Manager) Query(ctx context.Context, bs []byte, from netip.AddrPort) ([] return nil, errFullQueue } defer atomic.AddInt32(&m.activeQueriesAtomic, -1) - return m.resolver.Query(ctx, bs, from) + return m.resolver.Query(ctx, bs, family, from) } const ( @@ -371,7 +375,7 @@ func (s *dnsTCPSession) handleWrites() { } func (s *dnsTCPSession) handleQuery(q []byte) { - resp, err := s.m.Query(s.ctx, q, s.srcAddr) + resp, err := s.m.Query(s.ctx, q, "tcp", s.srcAddr) if err != nil { s.m.logf("tcp query: %v", err) return @@ -466,7 +470,7 @@ func Cleanup(logf logger.Logf, interfaceName string) { logf("creating dns cleanup: %v", err) return } - dns := NewManager(logf, oscfg, nil, &tsdial.Dialer{Logf: logf}, nil) + dns := NewManager(logf, oscfg, nil, &tsdial.Dialer{Logf: logf}, nil, nil) if err := dns.Down(); err != nil { logf("dns down: %v", err) } diff --git a/net/dns/manager_tcp_test.go b/net/dns/manager_tcp_test.go index 29e2e8a5b..87c0b258e 100644 --- a/net/dns/manager_tcp_test.go +++ b/net/dns/manager_tcp_test.go @@ -87,7 +87,7 @@ func TestDNSOverTCP(t *testing.T) { SearchDomains: fqdns("coffee.shop"), }, } - m := NewManager(t.Logf, &f, nil, new(tsdial.Dialer), nil) + m := NewManager(t.Logf, &f, nil, new(tsdial.Dialer), nil, nil) m.resolver.TestOnlySetHook(f.SetResolver) m.Set(Config{ Hosts: hosts( @@ -172,7 +172,7 @@ func TestDNSOverTCP_TooLarge(t *testing.T) { SearchDomains: fqdns("coffee.shop"), }, } - m := NewManager(log, &f, nil, new(tsdial.Dialer), nil) + m := NewManager(log, &f, nil, new(tsdial.Dialer), nil, nil) m.resolver.TestOnlySetHook(f.SetResolver) m.Set(Config{ Hosts: hosts("andrew.ts.com.", "1.2.3.4"), diff --git a/net/dns/manager_test.go b/net/dns/manager_test.go index 7997c4317..e88b61322 100644 --- a/net/dns/manager_test.go +++ b/net/dns/manager_test.go @@ -613,7 +613,7 @@ func TestManager(t *testing.T) { SplitDNS: test.split, BaseConfig: test.bs, } - m := NewManager(t.Logf, &f, nil, new(tsdial.Dialer), nil) + m := NewManager(t.Logf, &f, nil, new(tsdial.Dialer), nil, nil) m.resolver.TestOnlySetHook(f.SetResolver) if err := m.Set(test.in); err != nil { diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index 85670e1d6..edcf9bbe6 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -21,6 +21,7 @@ import ( "time" dns "golang.org/x/net/dns/dnsmessage" + "tailscale.com/control/controlknobs" "tailscale.com/envknob" "tailscale.com/net/dns/publicdns" "tailscale.com/net/dnscache" @@ -68,6 +69,10 @@ const ( // DNS queries to the "fallback" DNS server IP for a known provider // (e.g. how long to wait to query Google's 8.8.4.4 after 8.8.8.8). wellKnownHostBackupDelay = 200 * time.Millisecond + + // tcpQueryTimeout is the timeout for a DNS query performed over TCP. + // It matches the default 5sec timeout of the 'dig' utility. + tcpQueryTimeout = 5 * time.Second ) // txid identifies a DNS transaction. @@ -180,6 +185,8 @@ type forwarder struct { linkSel ForwardLinkSelector // TODO(bradfitz): remove this when tsdial.Dialer absorbs it dialer *tsdial.Dialer + controlKnobs *controlknobs.Knobs // or nil + ctx context.Context // good until Close ctxCancel context.CancelFunc // closes ctx @@ -206,12 +213,13 @@ func init() { rand.Seed(time.Now().UnixNano()) } -func newForwarder(logf logger.Logf, netMon *netmon.Monitor, linkSel ForwardLinkSelector, dialer *tsdial.Dialer) *forwarder { +func newForwarder(logf logger.Logf, netMon *netmon.Monitor, linkSel ForwardLinkSelector, dialer *tsdial.Dialer, knobs *controlknobs.Knobs) *forwarder { f := &forwarder{ - logf: logger.WithPrefix(logf, "forward: "), - netMon: netMon, - linkSel: linkSel, - dialer: dialer, + logf: logger.WithPrefix(logf, "forward: "), + netMon: netMon, + linkSel: linkSel, + dialer: dialer, + controlKnobs: knobs, } f.ctx, f.ctxCancel = context.WithCancel(context.Background()) return f @@ -443,7 +451,10 @@ func (f *forwarder) sendDoH(ctx context.Context, urlBase string, c *http.Client, return res, err } -var verboseDNSForward = envknob.RegisterBool("TS_DEBUG_DNS_FORWARD_SEND") +var ( + verboseDNSForward = envknob.RegisterBool("TS_DEBUG_DNS_FORWARD_SEND") + skipTCPRetry = envknob.RegisterBool("TS_DNS_FORWARD_SKIP_TCP_RETRY") +) // send sends packet to dst. It is best effort. // @@ -477,10 +488,49 @@ func (f *forwarder) send(ctx context.Context, fq *forwardQuery, rr resolverAndDe return nil, fmt.Errorf("tls:// resolvers not supported yet") } - return f.sendUDP(ctx, fq, rr) + ret, err = f.sendUDP(ctx, fq, rr) + if err != nil { + return nil, err + } + + if !truncatedFlagSet(ret) { + // Successful, non-truncated response; return it. + return ret, nil + } + if fq.family == "udp" { + // If this is a UDP query, return it regardless of whether the + // response is truncated or not; the client can retry + // communicating with tailscaled over TCP. There's no point + // falling back to TCP for a truncated query if we can't return + // the results to the client. + return ret, nil + } + if skipTCPRetry() || (f.controlKnobs != nil && f.controlKnobs.DisableDNSForwarderTCPRetries.Load()) { + // Envknob or control knob disabled the TCP retry behaviour; + // just return what we have. + return ret, nil + } + + // Don't retry if our context is done. + if err := ctx.Err(); err != nil { + return nil, err + } + + // Retry over TCP, best-effort; return the truncated UDP response if we + // cannot query via TCP. + if ret2, err2 := f.sendTCP(ctx, fq, rr); err2 == nil { + if verboseDNSForward() { + f.logf("forwarder.send(%q): successfully retried via TCP", rr.name.Addr) + } + return ret2, nil + } else if verboseDNSForward() { + f.logf("forwarder.send(%q): could not retry via TCP: %v", rr.name.Addr, err2) + } + return ret, nil } var errServerFailure = errors.New("response code indicates server issue") +var errTxIDMismatch = errors.New("txid doesn't match") func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAndDelay) (ret []byte, err error) { ipp, ok := rr.name.IPPort() @@ -545,7 +595,7 @@ func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAn txid := getTxID(out) if txid != fq.txid { metricDNSFwdUDPErrorTxID.Add(1) - return nil, errors.New("txid doesn't match") + return nil, errTxIDMismatch } rcode := getRCode(out) // don't forward transient errors back to the client when the server fails @@ -577,6 +627,92 @@ func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAn return out, nil } +func (f *forwarder) sendTCP(ctx context.Context, fq *forwardQuery, rr resolverAndDelay) (ret []byte, err error) { + ipp, ok := rr.name.IPPort() + if !ok { + metricDNSFwdErrorType.Add(1) + return nil, fmt.Errorf("unrecognized resolver type %q", rr.name.Addr) + } + metricDNSFwdTCP.Add(1) + ctx = sockstats.WithSockStats(ctx, sockstats.LabelDNSForwarderTCP, f.logf) + + // Specify the exact family to work around https://github.com/golang/go/issues/52264 + tcpFam := "tcp4" + if ipp.Addr().Is6() { + tcpFam = "tcp6" + } + + ctx, cancel := context.WithTimeout(ctx, tcpQueryTimeout) + defer cancel() + + conn, err := f.dialer.SystemDial(ctx, tcpFam, ipp.String()) + if err != nil { + return nil, err + } + defer conn.Close() + + fq.closeOnCtxDone.Add(conn) + defer fq.closeOnCtxDone.Remove(conn) + + ctxOrErr := func(err2 error) ([]byte, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + return nil, err2 + } + + // Write the query to the server. + query := make([]byte, len(fq.packet)+2) + binary.BigEndian.PutUint16(query, uint16(len(fq.packet))) + copy(query[2:], fq.packet) + if _, err := conn.Write(query); err != nil { + metricDNSFwdTCPErrorWrite.Add(1) + return ctxOrErr(err) + } + + metricDNSFwdTCPWrote.Add(1) + + // Read the header length back from the server + var length uint16 + if err := binary.Read(conn, binary.BigEndian, &length); err != nil { + metricDNSFwdTCPErrorRead.Add(1) + return ctxOrErr(err) + } + + // Now read the response + out := make([]byte, length) + n, err := io.ReadFull(conn, out) + if err != nil { + metricDNSFwdTCPErrorRead.Add(1) + return ctxOrErr(err) + } + + if n < int(length) { + f.logf("sendTCP: packet too small (%d bytes)", n) + return nil, io.ErrUnexpectedEOF + } + out = out[:n] + txid := getTxID(out) + if txid != fq.txid { + metricDNSFwdTCPErrorTxID.Add(1) + return nil, errTxIDMismatch + } + + rcode := getRCode(out) + + // don't forward transient errors back to the client when the server fails + if rcode == dns.RCodeServerFailure { + f.logf("sendTCP: response code indicating server failure: %d", rcode) + metricDNSFwdTCPErrorServer.Add(1) + return nil, errServerFailure + } + + // TODO(andrew): do we need to do this? + //clampEDNSSize(out, maxResponseBytes) + metricDNSFwdTCPSuccess.Add(1) + return out, nil +} + // resolvers returns the resolvers to use for domain. func (f *forwarder) resolvers(domain dnsname.FQDN) []resolverAndDelay { f.mu.Lock() @@ -601,6 +737,7 @@ func (f *forwarder) resolvers(domain dnsname.FQDN) []resolverAndDelay { type forwardQuery struct { txid txid packet []byte + family string // "tcp" or "udp" // closeOnCtxDone lets send register values to Close if the // caller's ctx expires. This avoids send from allocating its @@ -686,6 +823,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo fq := &forwardQuery{ txid: getTxID(query.bs), packet: query.bs, + family: query.family, closeOnCtxDone: new(closePool), } defer fq.closeOnCtxDone.Close() @@ -727,7 +865,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo case <-ctx.Done(): metricDNSFwdErrorContext.Add(1) return ctx.Err() - case responseChan <- packet{v, query.addr}: + case responseChan <- packet{v, query.family, query.addr}: metricDNSFwdSuccess.Add(1) return nil } diff --git a/net/dns/resolver/forwarder_test.go b/net/dns/resolver/forwarder_test.go index dad165fe7..b78e26c95 100644 --- a/net/dns/resolver/forwarder_test.go +++ b/net/dns/resolver/forwarder_test.go @@ -4,14 +4,26 @@ package resolver import ( + "bytes" + "context" + "encoding/binary" "flag" "fmt" + "io" + "net" + "net/netip" + "os" "reflect" "strings" + "sync" + "sync/atomic" "testing" "time" dns "golang.org/x/net/dns/dnsmessage" + "tailscale.com/envknob" + "tailscale.com/net/netmon" + "tailscale.com/net/tsdial" "tailscale.com/types/dnstype" ) @@ -240,3 +252,224 @@ func FuzzClampEDNSSize(f *testing.F) { clampEDNSSize(data, maxResponseBytes) }) } + +func runDNSServer(tb testing.TB, response []byte, onRequest func(bool, []byte)) (port uint16) { + tcpResponse := make([]byte, len(response)+2) + binary.BigEndian.PutUint16(tcpResponse, uint16(len(response))) + copy(tcpResponse[2:], response) + + // Repeatedly listen until we can get the same port. + const tries = 25 + var ( + tcpLn *net.TCPListener + udpLn *net.UDPConn + err error + ) + for try := 0; try < tries; try++ { + if tcpLn != nil { + tcpLn.Close() + tcpLn = nil + } + + tcpLn, err = net.ListenTCP("tcp4", &net.TCPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 0, // Choose one + }) + if err != nil { + tb.Fatal(err) + } + udpLn, err = net.ListenUDP("udp4", &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: tcpLn.Addr().(*net.TCPAddr).Port, + }) + if err == nil { + break + } + } + if tcpLn == nil || udpLn == nil { + if tcpLn != nil { + tcpLn.Close() + } + if udpLn != nil { + udpLn.Close() + } + + // Skip instead of being fatal to avoid flaking on extremely + // heavily-loaded CI systems. + tb.Skipf("failed to listen on same port for TCP/UDP after %d tries", tries) + } + + port = uint16(tcpLn.Addr().(*net.TCPAddr).Port) + + handleConn := func(conn net.Conn) { + defer conn.Close() + + // Read the length header, then the buffer + var length uint16 + if err := binary.Read(conn, binary.BigEndian, &length); err != nil { + tb.Logf("error reading length header: %v", err) + return + } + req := make([]byte, length) + n, err := io.ReadFull(conn, req) + if err != nil { + tb.Logf("error reading query: %v", err) + return + } + req = req[:n] + onRequest(true, req) + + // Write response + if _, err := conn.Write(tcpResponse); err != nil { + tb.Logf("error writing response: %v", err) + return + } + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for { + conn, err := tcpLn.Accept() + if err != nil { + return + } + go handleConn(conn) + } + }() + + handleUDP := func(addr netip.AddrPort, req []byte) { + onRequest(false, req) + if _, err := udpLn.WriteToUDPAddrPort(response, addr); err != nil { + tb.Logf("error writing response: %v", err) + } + } + + wg.Add(1) + go func() { + defer wg.Done() + for { + buf := make([]byte, 65535) + n, addr, err := udpLn.ReadFromUDPAddrPort(buf) + if err != nil { + return + } + buf = buf[:n] + go handleUDP(addr, buf) + } + }() + + tb.Cleanup(func() { + tcpLn.Close() + udpLn.Close() + tb.Logf("waiting for listeners to finish...") + wg.Wait() + }) + return +} + +func TestForwarderTCPFallback(t *testing.T) { + const debugKnob = "TS_DEBUG_DNS_FORWARD_SEND" + oldVal := os.Getenv(debugKnob) + envknob.Setenv(debugKnob, "true") + t.Cleanup(func() { envknob.Setenv(debugKnob, oldVal) }) + + const domain = "large-dns-response.tailscale.com." + + // Make a response that's very large, containing a bunch of localhost addresses. + largeResponse := func() []byte { + name := dns.MustNewName(domain) + + builder := dns.NewBuilder(nil, dns.Header{}) + builder.StartQuestions() + builder.Question(dns.Question{ + Name: name, + Type: dns.TypeA, + Class: dns.ClassINET, + }) + builder.StartAnswers() + for i := 0; i < 120; i++ { + builder.AResource(dns.ResourceHeader{ + Name: name, + Class: dns.ClassINET, + TTL: 300, + }, dns.AResource{ + A: [4]byte{127, 0, 0, byte(i)}, + }) + } + + msg, err := builder.Finish() + if err != nil { + t.Fatal(err) + } + return msg + }() + if len(largeResponse) <= maxResponseBytes { + t.Fatalf("got len(largeResponse)=%d, want > %d", len(largeResponse), maxResponseBytes) + } + + // Our request is a single A query for the domain in the answer, above. + request := func() []byte { + builder := dns.NewBuilder(nil, dns.Header{}) + builder.StartQuestions() + builder.Question(dns.Question{ + Name: dns.MustNewName(domain), + Type: dns.TypeA, + Class: dns.ClassINET, + }) + msg, err := builder.Finish() + if err != nil { + t.Fatal(err) + } + return msg + }() + + var sawUDPRequest, sawTCPRequest atomic.Bool + port := runDNSServer(t, largeResponse, func(isTCP bool, gotRequest []byte) { + if isTCP { + sawTCPRequest.Store(true) + } else { + sawUDPRequest.Store(true) + } + + if !bytes.Equal(request, gotRequest) { + t.Errorf("invalid request\ngot: %+v\nwant: %+v", gotRequest, request) + } + }) + + netMon, err := netmon.New(t.Logf) + if err != nil { + t.Fatal(err) + } + + var dialer tsdial.Dialer + dialer.SetNetMon(netMon) + + fwd := newForwarder(t.Logf, netMon, nil, &dialer, nil) + + fq := &forwardQuery{ + txid: getTxID(request), + packet: request, + closeOnCtxDone: new(closePool), + } + defer fq.closeOnCtxDone.Close() + + rr := resolverAndDelay{ + name: &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port)}, + } + + resp, err := fwd.send(context.Background(), fq, rr) + if err != nil { + t.Fatalf("error making request: %v", err) + } + if !bytes.Equal(resp, largeResponse) { + t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse) + } + if !sawTCPRequest.Load() { + t.Errorf("DNS server never saw TCP request") + } + if !sawUDPRequest.Load() { + t.Errorf("DNS server never saw UDP request") + } +} diff --git a/net/dns/resolver/tsdns.go b/net/dns/resolver/tsdns.go index 2b5a0869e..7c2af5b16 100644 --- a/net/dns/resolver/tsdns.go +++ b/net/dns/resolver/tsdns.go @@ -23,6 +23,7 @@ import ( "time" dns "golang.org/x/net/dns/dnsmessage" + "tailscale.com/control/controlknobs" "tailscale.com/envknob" "tailscale.com/net/dns/resolvconffile" "tailscale.com/net/netaddr" @@ -53,8 +54,9 @@ var ( ) type packet struct { - bs []byte - addr netip.AddrPort // src for a request, dst for a response + bs []byte + family string // either "tcp" or "udp" + addr netip.AddrPort // src for a request, dst for a response } // Config is a resolver configuration. @@ -206,7 +208,7 @@ type ForwardLinkSelector interface { // New returns a new resolver. // netMon optionally specifies a network monitor to use for socket rebinding. -func New(logf logger.Logf, netMon *netmon.Monitor, linkSel ForwardLinkSelector, dialer *tsdial.Dialer) *Resolver { +func New(logf logger.Logf, netMon *netmon.Monitor, linkSel ForwardLinkSelector, dialer *tsdial.Dialer, knobs *controlknobs.Knobs) *Resolver { if dialer == nil { panic("nil Dialer") } @@ -218,7 +220,7 @@ func New(logf logger.Logf, netMon *netmon.Monitor, linkSel ForwardLinkSelector, ipToHost: map[netip.Addr]dnsname.FQDN{}, dialer: dialer, } - r.forwarder = newForwarder(r.logf, netMon, linkSel, dialer) + r.forwarder = newForwarder(r.logf, netMon, linkSel, dialer, knobs) return r } @@ -266,7 +268,7 @@ func (r *Resolver) Close() { // bound on per-query resource usage. const dnsQueryTimeout = 10 * time.Second -func (r *Resolver) Query(ctx context.Context, bs []byte, from netip.AddrPort) ([]byte, error) { +func (r *Resolver) Query(ctx context.Context, bs []byte, family string, from netip.AddrPort) ([]byte, error) { metricDNSQueryLocal.Add(1) select { case <-r.closed: @@ -281,7 +283,7 @@ func (r *Resolver) Query(ctx context.Context, bs []byte, from netip.AddrPort) ([ ctx, cancel := context.WithTimeout(ctx, dnsQueryTimeout) defer close(responses) defer cancel() - err = r.forwarder.forwardWithDestChan(ctx, packet{bs, from}, responses) + err = r.forwarder.forwardWithDestChan(ctx, packet{bs, family, from}, responses) if err != nil { select { // Best effort: use any error response sent by forwardWithDestChan. @@ -369,7 +371,7 @@ func (r *Resolver) HandleExitNodeDNSQuery(ctx context.Context, q []byte, from ne }} } - err = r.forwarder.forwardWithDestChan(ctx, packet{q, from}, ch, resolvers...) + err = r.forwarder.forwardWithDestChan(ctx, packet{q, "tcp", from}, ch, resolvers...) if err != nil { metricDNSExitProxyErrorForward.Add(1) return nil, err @@ -1306,6 +1308,14 @@ var ( metricDNSFwdUDPErrorRead = clientmetric.NewCounter("dns_query_fwd_udp_error_read") metricDNSFwdUDPSuccess = clientmetric.NewCounter("dns_query_fwd_udp_success") + metricDNSFwdTCP = clientmetric.NewCounter("dns_query_fwd_tcp") // on entry + metricDNSFwdTCPWrote = clientmetric.NewCounter("dns_query_fwd_tcp_wrote") // sent TCP packet + metricDNSFwdTCPErrorWrite = clientmetric.NewCounter("dns_query_fwd_tcp_error_write") + metricDNSFwdTCPErrorServer = clientmetric.NewCounter("dns_query_fwd_tcp_error_server") + metricDNSFwdTCPErrorTxID = clientmetric.NewCounter("dns_query_fwd_tcp_error_txid") + metricDNSFwdTCPErrorRead = clientmetric.NewCounter("dns_query_fwd_tcp_error_read") + metricDNSFwdTCPSuccess = clientmetric.NewCounter("dns_query_fwd_tcp_success") + metricDNSFwdDoH = clientmetric.NewCounter("dns_query_fwd_doh") metricDNSFwdDoHErrorStatus = clientmetric.NewCounter("dns_query_fwd_doh_error_status") metricDNSFwdDoHErrorCT = clientmetric.NewCounter("dns_query_fwd_doh_error_content_type") diff --git a/net/dns/resolver/tsdns_test.go b/net/dns/resolver/tsdns_test.go index 8db97f49a..c135c92a1 100644 --- a/net/dns/resolver/tsdns_test.go +++ b/net/dns/resolver/tsdns_test.go @@ -233,7 +233,7 @@ func unpackResponse(payload []byte) (dnsResponse, error) { } func syncRespond(r *Resolver, query []byte) ([]byte, error) { - return r.Query(context.Background(), query, netip.AddrPort{}) + return r.Query(context.Background(), query, "udp", netip.AddrPort{}) } func mustIP(str string) netip.Addr { @@ -315,7 +315,7 @@ func TestRDNSNameToIPv6(t *testing.T) { } func newResolver(t testing.TB) *Resolver { - return New(t.Logf, nil /* no network monitor */, nil /* no link selector */, new(tsdial.Dialer)) + return New(t.Logf, nil /* no network monitor */, nil /* no link selector */, new(tsdial.Dialer), nil /* no control knobs */) } func TestResolveLocal(t *testing.T) { @@ -1016,7 +1016,7 @@ func TestForwardLinkSelection(t *testing.T) { return "special" } return "" - }), new(tsdial.Dialer)) + }), new(tsdial.Dialer), nil /* no control knobs */) // Test non-special IP. if got, err := fwd.packetListener(netip.Addr{}); err != nil { diff --git a/net/sockstats/label_string.go b/net/sockstats/label_string.go index 2c3fb6bd7..f9a111ad7 100644 --- a/net/sockstats/label_string.go +++ b/net/sockstats/label_string.go @@ -20,11 +20,12 @@ func _() { _ = x[LabelMagicsockConnUDP6-9] _ = x[LabelNetlogLogger-10] _ = x[LabelSockstatlogLogger-11] + _ = x[LabelDNSForwarderTCP-12] } -const _Label_name = "ControlClientAutoControlClientDialerDERPHTTPClientLogtailLoggerDNSForwarderDoHDNSForwarderUDPNetcheckClientPortmapperClientMagicsockConnUDP4MagicsockConnUDP6NetlogLoggerSockstatlogLogger" +const _Label_name = "ControlClientAutoControlClientDialerDERPHTTPClientLogtailLoggerDNSForwarderDoHDNSForwarderUDPNetcheckClientPortmapperClientMagicsockConnUDP4MagicsockConnUDP6NetlogLoggerSockstatlogLoggerDNSForwarderTCP" -var _Label_index = [...]uint8{0, 17, 36, 50, 63, 78, 93, 107, 123, 140, 157, 169, 186} +var _Label_index = [...]uint8{0, 17, 36, 50, 63, 78, 93, 107, 123, 140, 157, 169, 186, 201} func (i Label) String() string { if i >= Label(len(_Label_index)-1) { diff --git a/net/sockstats/sockstats.go b/net/sockstats/sockstats.go index b39d60afb..715c1ee06 100644 --- a/net/sockstats/sockstats.go +++ b/net/sockstats/sockstats.go @@ -51,6 +51,7 @@ const ( LabelMagicsockConnUDP6 Label = 9 // wgengine/magicsock/magicsock.go LabelNetlogLogger Label = 10 // wgengine/netlog/logger.go LabelSockstatlogLogger Label = 11 // log/sockstatlog/logger.go + LabelDNSForwarderTCP Label = 12 // net/dns/resolver/forwarder.go ) // WithSockStats instruments a context so that sockets created with it will diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index 180f588a3..9ead3525b 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -114,7 +114,8 @@ type CapabilityVersion int // - 72: 2023-08-23: TS-2023-006 UPnP issue fixed; UPnP can now be used again // - 73: 2023-09-01: Non-Windows clients expect to receive ClientVersion // - 74: 2023-09-18: Client understands NodeCapMap -const CurrentCapabilityVersion CapabilityVersion = 74 +// - 75: 2023-09-12: Client understands NodeAttrDNSForwarderDisableTCPRetries +const CurrentCapabilityVersion CapabilityVersion = 75 type StableID string @@ -2137,6 +2138,10 @@ const ( // NodeAttrPeerMTUEnable makes the client do path MTU discovery to its // peers. If it isn't set, it defaults to the client default. NodeAttrPeerMTUEnable NodeCapability = "peer-mtu-enable" + + // NodeAttrDNSForwarderDisableTCPRetries disables retrying truncated + // DNS queries over TCP if the response is truncated. + NodeAttrDNSForwarderDisableTCPRetries NodeCapability = "dns-forwarder-disable-tcp-retries" ) // SetDNSRequest is a request to add a DNS record. diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index e2f930b4c..acff80fbc 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -1086,7 +1086,7 @@ func (ns *Impl) handleMagicDNSUDP(srcAddr netip.AddrPort, c *gonet.UDPConn) { } return } - resp, err := ns.dns.Query(context.Background(), q[:n], srcAddr) + resp, err := ns.dns.Query(context.Background(), q[:n], "udp", srcAddr) if err != nil { ns.logf("dns udp query: %v", err) return diff --git a/wgengine/userspace.go b/wgengine/userspace.go index 397348ee2..071b71c91 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -304,7 +304,7 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) tunName, _ := conf.Tun.Name() conf.Dialer.SetTUNName(tunName) conf.Dialer.SetNetMon(e.netMon) - e.dns = dns.NewManager(logf, conf.DNS, e.netMon, conf.Dialer, fwdDNSLinkSelector{e, tunName}) + e.dns = dns.NewManager(logf, conf.DNS, e.netMon, conf.Dialer, fwdDNSLinkSelector{e, tunName}, conf.ControlKnobs) // TODO: there's probably a better place for this sockstats.SetNetMon(e.netMon)