From d1d6ab068ebd91de6688367ffbc33206631cba24 Mon Sep 17 00:00:00 2001
From: Tom
Date: Thu, 5 May 2022 16:42:45 -0700
Subject: [PATCH] net/dns, wgengine: implement DNS over TCP (#4598)

* net/dns, wgengine: implement DNS over TCP

Signed-off-by: Tom DNetto

* wgengine/netstack: intercept only relevant port/protocols to quad-100

Signed-off-by: Tom DNetto
---
 net/dns/manager.go | 119 +++++++++++++++++++++++++++++
 net/dns/manager_tcp_test.go | 136 ++++++++++++++++++++++++++++++++++
 wgengine/netstack/netstack.go | 18 +++++
 3 files changed, 273 insertions(+)
 create mode 100644 net/dns/manager_tcp_test.go

diff --git a/net/dns/manager.go b/net/dns/manager.go
index 36040a7e8..9b0fef4a1 100644
--- a/net/dns/manager.go
+++ b/net/dns/manager.go
@@ -7,7 +7,9 @@ package dns
 import (
 	"bufio"
 	"context"
+	"encoding/binary"
 	"errors"
+	"io"
 	"net"
 	"runtime"
 	"sync/atomic"
@@ -346,6 +348,123 @@ func (m *Manager) Query(ctx context.Context, bs []byte, from netaddr.IPPort) ([]
 	return m.resolver.Query(ctx, bs, from)
 }
 
+const (
+	// RFC 7766 6.2 recommends connection reuse & request pipelining
+	// be undertaken, and the connection be closed by the server
+	// using an idle timeout on the order of seconds.
+	idleTimeoutTCP = 45 * time.Second
+	// The RFCs don't specify the max size of a TCP-based DNS query,
+	// but we want to keep this reasonable. Given payloads are typically
+	// much smaller and all known clients send a single query, I've
+	// arbitrarily chosen 2k.
+	maxReqSizeTCP = 2048
+)
+
+// dnsTCPSession services DNS requests sent over TCP.
+type dnsTCPSession struct {
+	m *Manager
+
+	conn    net.Conn
+	srcAddr netaddr.IPPort
+
+	// readClosing is closed by handleReads when it returns, which tells
+	// handleWrites to tear the session down.
+	readClosing chan struct{}
+	responses   chan []byte // DNS replies pending writing
+
+	// ctx is cancelled (via closeCtx) when handleWrites returns. It is
+	// passed to every query and releases goroutines that are waiting to
+	// queue a reply.
+	ctx      context.Context
+	closeCtx context.CancelFunc
+}
+
+// handleWrites writes queued DNS replies to the connection, each one
+// prefixed with its length as a big-endian uint16. It returns once the
+// read side has shut down or a write fails; on return it cancels the
+// session context and closes the connection.
+func (s *dnsTCPSession) handleWrites() {
+	defer s.conn.Close()
+	// s.responses is deliberately never closed. handleQuery goroutines may
+	// still be selecting on a send to it during teardown, and a send on a
+	// closed channel panics. Cancelling the context is what releases them.
+	defer s.closeCtx()
+
+	for {
+		select {
+		case <-s.readClosing:
+			return // connection closed or timeout, teardown time
+
+		case resp := <-s.responses:
+			s.conn.SetWriteDeadline(time.Now().Add(idleTimeoutTCP))
+			if err := binary.Write(s.conn, binary.BigEndian, uint16(len(resp))); err != nil {
+				s.m.logf("tcp write (len): %v", err)
+				return
+			}
+			if _, err := s.conn.Write(resp); err != nil {
+				s.m.logf("tcp write (response): %v", err)
+				return
+			}
+		}
+	}
+}
+
+// handleQuery resolves a single request through the Manager and queues
+// the reply for handleWrites. A failed query is logged and produces no
+// reply. If the session context is cancelled before the reply can be
+// queued, the reply is dropped.
+func (s *dnsTCPSession) handleQuery(q []byte) {
+	resp, err := s.m.Query(s.ctx, q, s.srcAddr)
+	if err != nil {
+		s.m.logf("tcp query: %v", err)
+		return
+	}
+
+	select {
+	case <-s.ctx.Done():
+	case s.responses <- resp:
+	}
+}
+
+// handleReads reads length-prefixed requests (a big-endian uint16 length
+// followed by that many bytes) from the connection until the session
+// context is cancelled or a read fails. Each request is served in its own
+// goroutine, so replies may be queued in a different order than the
+// requests arrived. readClosing is closed on return.
+func (s *dnsTCPSession) handleReads() {
+	defer close(s.readClosing)
+
+	for s.ctx.Err() == nil {
+		// The deadline is refreshed before every request, so it acts as
+		// an idle timeout.
+		s.conn.SetReadDeadline(time.Now().Add(idleTimeoutTCP))
+		var reqLen uint16
+		if err := binary.Read(s.conn, binary.BigEndian, &reqLen); err != nil {
+			if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
+				return // connection closed normally, nothing to report
+			}
+			s.m.logf("tcp read (len): %v", err)
+			return
+		}
+		if int(reqLen) > maxReqSizeTCP {
+			s.m.logf("tcp request too large (%d > %d)", reqLen, maxReqSizeTCP)
+			return
+		}
+
+		buf := make([]byte, int(reqLen))
+		if _, err := io.ReadFull(s.conn, buf); err != nil {
+			s.m.logf("tcp read (payload): %v", err)
+			return
+		}
+
+		// Don't start a query for a session that was torn down while
+		// the request was being read.
+		if s.ctx.Err() != nil {
+			return
+		}
+		go s.handleQuery(buf)
+	}
+}
+
+// HandleTCPConn implements magicDNS over TCP, taking a connection and
+// servicing DNS requests sent down it.
+func (m *Manager) HandleTCPConn(conn net.Conn, srcAddr netaddr.IPPort) {
+	s := dnsTCPSession{
+		m:           m,
+		conn:        conn,
+		srcAddr:     srcAddr,
+		responses:   make(chan []byte),
+		readClosing: make(chan struct{}),
+	}
+	s.ctx, s.closeCtx = context.WithCancel(context.Background())
+	// Reads run in the background; writes run on the caller's goroutine,
+	// so this call blocks until the session is torn down.
+	go s.handleReads()
+	s.handleWrites()
+}
+
 func (m *Manager) Down() error {
 	m.ctxCancel()
 	if err := m.os.Close(); err != nil {
diff --git a/net/dns/manager_tcp_test.go b/net/dns/manager_tcp_test.go
new file mode 100644
index 000000000..26b969053
--- /dev/null
+++ b/net/dns/manager_tcp_test.go
@@ -0,0 +1,136 @@
+// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+import (
+	"encoding/binary"
+	"io"
+	"net"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+	dns "golang.org/x/net/dns/dnsmessage"
+	"inet.af/netaddr"
+	"tailscale.com/net/tsdial"
+	"tailscale.com/util/dnsname"
+)
+
+// mkDNSRequest builds a wire-format DNS query containing one question for
+// domain of type tp, plus an OPT record in the additional section. It
+// panics if the message cannot be built.
+func mkDNSRequest(domain dnsname.FQDN, tp dns.Type) []byte {
+	var dnsHeader dns.Header
+	question := dns.Question{
+		Name:  dns.MustNewName(domain.WithTrailingDot()),
+		Type:  tp,
+		Class: dns.ClassINET,
+	}
+
+	builder := dns.NewBuilder(nil, dnsHeader)
+	if err := builder.StartQuestions(); err != nil {
+		panic(err)
+	}
+	if err := builder.Question(question); err != nil {
+		panic(err)
+	}
+
+	if err := builder.StartAdditionals(); err != nil {
+		panic(err)
+	}
+
+	ednsHeader := dns.ResourceHeader{
+		Name:  dns.MustNewName("."),
+		Type:  dns.TypeOPT,
+		Class: dns.Class(4095),
+	}
+
+	if err := builder.OPTResource(ednsHeader, dns.OPTResource{}); err != nil {
+		panic(err)
+	}
+
+	// Treat a Finish failure like every other builder failure above
+	// instead of silently returning a possibly empty payload.
+	payload, err := builder.Finish()
+	if err != nil {
+		panic(err)
+	}
+
+	return payload
+}
+
+// TestDNSOverTCP sends several length-prefixed queries down one connection
+// served by HandleTCPConn before reading any reply, then checks that an A
+// record with the configured address comes back for each queried name.
+func TestDNSOverTCP(t *testing.T) {
+	f := fakeOSConfigurator{
+		SplitDNS: true,
+		BaseConfig: OSConfig{
+			Nameservers:   mustIPs("8.8.8.8"),
+			SearchDomains: fqdns("coffee.shop"),
+		},
+	}
+	m := NewManager(t.Logf, &f, nil, new(tsdial.Dialer), nil)
+	m.resolver.TestOnlySetHook(f.SetResolver)
+	m.Set(Config{
+		Hosts: hosts(
+			"dave.ts.com.", "1.2.3.4",
+			"bradfitz.ts.com.", "2.3.4.5"),
+		Routes:        upstreams("ts.com", ""),
+		SearchDomains: fqdns("tailscale.com", "universe.tf"),
+	})
+	defer m.Down()
+
+	c, s := net.Pipe()
+	defer s.Close()
+	go m.HandleTCPConn(s, netaddr.IPPort{})
+	defer c.Close()
+
+	wantResults := map[dnsname.FQDN]string{
+		"dave.ts.com.":     "1.2.3.4",
+		"bradfitz.ts.com.": "2.3.4.5",
+	}
+
+	// Send every request before reading any reply.
+	for domain := range wantResults {
+		b := mkDNSRequest(domain, dns.TypeA)
+		if err := binary.Write(c, binary.BigEndian, uint16(len(b))); err != nil {
+			t.Fatalf("writing len: %v", err)
+		}
+		if _, err := c.Write(b); err != nil {
+			t.Fatalf("writing data: %v", err)
+		}
+	}
+
+	// Replies are matched to requests by the name in the echoed question,
+	// not by arrival order.
+	results := map[dnsname.FQDN]string{}
+	for i := 0; i < len(wantResults); i++ {
+		var respLength uint16
+		if err := binary.Read(c, binary.BigEndian, &respLength); err != nil {
+			t.Fatalf("reading len: %v", err)
+		}
+		resp := make([]byte, int(respLength))
+		if _, err := io.ReadFull(c, resp); err != nil {
+			t.Fatalf("reading data: %v", err)
+		}
+
+		var parser dns.Parser
+		if _, err := parser.Start(resp); err != nil {
+			t.Errorf("parser.Start() failed: %v", err)
+			continue
+		}
+		q, err := parser.Question()
+		if err != nil {
+			t.Errorf("parser.Question(): %v", err)
+			continue
+		}
+		if err := parser.SkipAllQuestions(); err != nil {
+			t.Errorf("parser.SkipAllQuestions(): %v", err)
+			continue
+		}
+		ah, err := parser.AnswerHeader()
+		if err != nil {
+			t.Errorf("parser.AnswerHeader(): %v", err)
+			continue
+		}
+		if ah.Type != dns.TypeA {
+			t.Errorf("unexpected answer type: got %v, want %v", ah.Type, dns.TypeA)
+			continue
+		}
+		res, err := parser.AResource()
+		if err != nil {
+			t.Errorf("parser.AResource(): %v", err)
+			continue
+		}
+		results[dnsname.FQDN(q.Name.String())] = net.IP(res.A[:]).String()
+	}
+	c.Close()
+
+	// cmp.Diff(x, y) marks x with '-' and y with '+', so this is -want +got.
+	if diff := cmp.Diff(wantResults, results); diff != "" {
+		t.Errorf("wrong results (-want+got)\n%s", diff)
+	}
+}
diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go
index 2c9766ddc..03e03383a 100644
--- a/wgengine/netstack/netstack.go
+++ b/wgengine/netstack/netstack.go @@ -373,6 +373,19 @@ func (ns *Impl) handleLocalPackets(p *packet.Parsed, t *tstun.Wrapper) filter.Re if dst := p.Dst.IP(); dst != magicDNSIP && dst != magicDNSIPv6 { return filter.Accept } + // Of traffic to the service IP, we only care about UDP 53, and TCP + // on port 80 & 53. + switch p.IPProto { + case ipproto.TCP: + if port := p.Dst.Port(); port != 53 && port != 80 { + return filter.Accept + } + case ipproto.UDP: + if port := p.Dst.Port(); port != 53 { + return filter.Accept + } + } + var pn tcpip.NetworkProtocolNumber switch p.IPVersion { @@ -758,6 +771,11 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) { // block until the TCP handshake is complete. c := gonet.NewTCPConn(&wq, ep) + if reqDetails.LocalPort == 53 && (dialIP == magicDNSIP || dialIP == magicDNSIPv6) { + go ns.dns.HandleTCPConn(c, netaddr.IPPortFrom(clientRemoteIP, reqDetails.RemotePort)) + return + } + if ns.lb != nil { if reqDetails.LocalPort == 22 && ns.processSSH() && ns.isLocalIP(dialIP) { if err := ns.lb.HandleSSHConn(c); err != nil {