diff --git a/net/dns/manager.go b/net/dns/manager.go
index c1086da7e..7d7d29689 100644
--- a/net/dns/manager.go
+++ b/net/dns/manager.go
@@ -447,9 +447,17 @@ type dnsTCPSession struct {
 
 func (s *dnsTCPSession) handleWrites() {
 	defer s.conn.Close()
-	defer close(s.responses)
 	defer s.closeCtx()
 
+	// NOTE(andrew): we explicitly do not close the 'responses' channel
+	// when this function exits. If we hit an error and return, we could
+	// still have outstanding 'handleQuery' goroutines running, and if we
+	// closed this channel they'd end up trying to send on a closed channel
+	// when they finish.
+	//
+	// Because we call closeCtx, those goroutines will not hang since they
+	// select on <-s.ctx.Done() as well as s.responses.
+
 	for {
 		select {
 		case <-s.readClosing:
@@ -476,6 +484,7 @@ func (s *dnsTCPSession) handleQuery(q []byte) {
 		return
 	}
 
+	// See note in handleWrites (above) regarding this select{}
 	select {
 	case <-s.ctx.Done():
 	case s.responses <- resp:
@@ -483,6 +492,7 @@ func (s *dnsTCPSession) handleQuery(q []byte) {
 }
 
 func (s *dnsTCPSession) handleReads() {
+	defer s.conn.Close()
 	defer close(s.readClosing)
 
 	for {
@@ -515,6 +525,11 @@ func (s *dnsTCPSession) handleReads() {
 			case <-s.ctx.Done():
 				return
 			default:
+				// NOTE: by kicking off the query handling in a
+				// new goroutine, it is possible that we'll
+				// deliver responses out-of-order. This is
+				// explicitly allowed by RFC7766, Section
+				// 6.2.1.1 ("Query Pipelining").
 				go s.handleQuery(buf)
 			}
 		}
diff --git a/net/dns/manager_tcp_test.go b/net/dns/manager_tcp_test.go
index 5aafb1532..0f886b214 100644
--- a/net/dns/manager_tcp_test.go
+++ b/net/dns/manager_tcp_test.go
@@ -5,19 +5,23 @@
 package dns
 
 import (
+	"bytes"
 	"encoding/binary"
+	"errors"
 	"io"
 	"net"
 	"net/netip"
 	"testing"
+	"time"
 
 	"github.com/google/go-cmp/cmp"
 	dns "golang.org/x/net/dns/dnsmessage"
 	"tailscale.com/net/tsdial"
+	"tailscale.com/tstest"
 	"tailscale.com/util/dnsname"
 )
 
-func mkDNSRequest(domain dnsname.FQDN, tp dns.Type) []byte {
+func mkDNSRequest(domain dnsname.FQDN, tp dns.Type, modify func(*dns.Builder)) []byte {
 	var dnsHeader dns.Header
 	question := dns.Question{
 		Name:  dns.MustNewName(domain.WithTrailingDot()),
@@ -37,6 +41,15 @@ func mkDNSRequest(domain dnsname.FQDN, tp dns.Type) []byte {
 		panic(err)
 	}
 
+	if modify != nil {
+		modify(&builder)
+	}
+	payload, _ := builder.Finish()
+
+	return payload
+}
+
+func addEDNS(builder *dns.Builder) {
 	ednsHeader := dns.ResourceHeader{
 		Name:  dns.MustNewName("."),
 		Type:  dns.TypeOPT,
@@ -46,10 +59,25 @@ func mkDNSRequest(domain dnsname.FQDN, tp dns.Type) []byte {
 	if err := builder.OPTResource(ednsHeader, dns.OPTResource{}); err != nil {
 		panic(err)
 	}
+}
 
-	payload, _ := builder.Finish()
+func mkLargeDNSRequest(domain dnsname.FQDN, tp dns.Type) []byte {
+	return mkDNSRequest(domain, tp, func(builder *dns.Builder) {
+		ednsHeader := dns.ResourceHeader{
+			Name:  dns.MustNewName("."),
+			Type:  dns.TypeOPT,
+			Class: dns.Class(4095),
+		}
 
-	return payload
+		if err := builder.OPTResource(ednsHeader, dns.OPTResource{
+			Options: []dns.Option{{
+				Code: 1234,
+				Data: bytes.Repeat([]byte("A"), maxReqSizeTCP),
+			}},
+		}); err != nil {
+			panic(err)
+		}
+	})
 }
 
 func TestDNSOverTCP(t *testing.T) {
@@ -82,7 +110,7 @@ func TestDNSOverTCP(t *testing.T) {
 	}
 
 	for domain, _ := range wantResults {
-		b := mkDNSRequest(domain, dns.TypeA)
+		b := mkDNSRequest(domain, dns.TypeA, addEDNS)
 		binary.Write(c, binary.BigEndian, uint16(len(b)))
 		c.Write(b)
 	}
@@ -134,3 +162,69 @@ func TestDNSOverTCP(t *testing.T) {
 		t.Errorf("wrong results (-got+want)\n%s", diff)
 	}
 }
+
+func TestDNSOverTCP_TooLarge(t *testing.T) {
+	log := tstest.WhileTestRunningLogger(t)
+
+	f := fakeOSConfigurator{
+		SplitDNS: true,
+		BaseConfig: OSConfig{
+			Nameservers:   mustIPs("8.8.8.8"),
+			SearchDomains: fqdns("coffee.shop"),
+		},
+	}
+	m := NewManager(log, &f, nil, new(tsdial.Dialer), nil)
+	m.resolver.TestOnlySetHook(f.SetResolver)
+	m.Set(Config{
+		Hosts:         hosts("andrew.ts.com.", "1.2.3.4"),
+		Routes:        upstreams("ts.com", ""),
+		SearchDomains: fqdns("tailscale.com"),
+	})
+	defer m.Down()
+
+	c, s := net.Pipe()
+	defer s.Close()
+	go m.HandleTCPConn(s, netip.AddrPort{})
+	defer c.Close()
+
+	var b []byte
+	domain := dnsname.FQDN("andrew.ts.com.")
+
+	// Write a successful request, then a large one that will fail; this
+	// exercises the data race in tailscale/tailscale#6725
+	b = mkDNSRequest(domain, dns.TypeA, addEDNS)
+	binary.Write(c, binary.BigEndian, uint16(len(b)))
+	if _, err := c.Write(b); err != nil {
+		t.Fatal(err)
+	}
+
+	c.SetWriteDeadline(time.Now().Add(5 * time.Second))
+
+	b = mkLargeDNSRequest(domain, dns.TypeA)
+	if err := binary.Write(c, binary.BigEndian, uint16(len(b))); err != nil {
+		t.Fatal(err)
+	}
+	if _, err := c.Write(b); err != nil {
+		// It's possible that we get an error here, since the
+		// net.Pipe() implementation enforces synchronous reads. So,
+		// handleReads could read the size, then error, and this write
+		// fails. That's actually a success for this test!
+		if errors.Is(err, io.ErrClosedPipe) {
+			t.Logf("pipe (correctly) closed when writing large response")
+			return
+		}
+
+		t.Fatal(err)
+	}
+
+	t.Logf("reading responses")
+	c.SetReadDeadline(time.Now().Add(5 * time.Second))
+
+	// We expect an EOF now, since the connection will have been closed due
+	// to a too-large query.
+	var respLength uint16
+	err := binary.Read(c, binary.BigEndian, &respLength)
+	if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrClosedPipe) {
+		t.Errorf("expected EOF on large read; got %v", err)
+	}
+}
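
The NOTE(andrew) comment in the patch describes a general Go pattern: never close a channel that other goroutines may still send on; cancel a shared context instead, and make every sender select on ctx.Done() so it can bail out once the receiver is gone. The standalone sketch below illustrates that shape only; it is not code from this patch, and the session type, its fields, and handleQuery here are hypothetical stand-ins for the dnsTCPSession machinery.

// Sketch of the "cancel, don't close" pattern the patch adopts.
package main

import (
	"context"
	"fmt"
	"sync"
)

type session struct {
	ctx       context.Context
	closeCtx  context.CancelFunc
	responses chan string
}

// handleQuery selects on both the context and the responses channel, so it
// never blocks (or panics sending on a closed channel) after the writer exits.
func (s *session) handleQuery(q string) {
	resp := "answer for " + q
	select {
	case <-s.ctx.Done():
		// Writer is gone; drop the response instead of sending.
	case s.responses <- resp:
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	s := &session{ctx: ctx, closeCtx: cancel, responses: make(chan string)}

	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			s.handleQuery(fmt.Sprintf("query-%d", i))
		}(i)
	}

	// Writer: read one response, then exit "early" as if the connection
	// failed. It cancels the context but never closes the responses
	// channel, so the outstanding handleQuery goroutines finish cleanly
	// via the <-ctx.Done() case rather than panicking.
	fmt.Println(<-s.responses)
	s.closeCtx()

	wg.Wait()
}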
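The test's binary.Write/binary.Read calls implement RFC 7766 framing: each DNS-over-TCP message is a two-byte big-endian length prefix followed by the payload. A minimal sketch of that framing follows, under the assumption stated here that writeMsg and readMsg are hypothetical helper names, not functions from the Tailscale codebase.

// Sketch of RFC 7766 length-prefixed DNS-over-TCP framing.
package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
)

// writeMsg writes one length-prefixed DNS message to w.
func writeMsg(w io.Writer, payload []byte) error {
	if err := binary.Write(w, binary.BigEndian, uint16(len(payload))); err != nil {
		return err
	}
	_, err := w.Write(payload)
	return err
}

// readMsg reads one length-prefixed DNS message from r.
func readMsg(r io.Reader) ([]byte, error) {
	var n uint16
	if err := binary.Read(r, binary.BigEndian, &n); err != nil {
		return nil, err
	}
	buf := make([]byte, n)
	if _, err := io.ReadFull(r, buf); err != nil {
		return nil, err
	}
	return buf, nil
}

func main() {
	var conn bytes.Buffer // stands in for a net.Conn
	if err := writeMsg(&conn, []byte("example query bytes")); err != nil {
		panic(err)
	}
	msg, err := readMsg(&conn)
	if err != nil {
		panic(err)
	}
	fmt.Printf("round-tripped %d bytes\n", len(msg))
}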