From 511840b1f6f955833a66b7979849478200d064da Mon Sep 17 00:00:00 2001 From: Dmytro Shynkevych Date: Mon, 8 Jun 2020 18:19:26 -0400 Subject: [PATCH] tsdns: initial implementation of a Tailscale DNS resolver (#396) Signed-off-by: Dmytro Shynkevych --- ipn/local.go | 31 ++++ wgengine/filter/filter.go | 41 +++-- wgengine/filter/filter_test.go | 18 ++- wgengine/packet/packet.go | 10 +- wgengine/packet/packet_test.go | 2 + wgengine/tsdns/tsdns.go | 274 +++++++++++++++++++++++++++++++++ wgengine/tstun/tun.go | 167 +++++++++++++++----- wgengine/tstun/tun_test.go | 6 +- wgengine/userspace.go | 124 ++++++++++++--- wgengine/watchdog.go | 4 + wgengine/watchdog_test.go | 9 +- wgengine/wgengine.go | 4 + 12 files changed, 582 insertions(+), 108 deletions(-) create mode 100644 wgengine/tsdns/tsdns.go diff --git a/ipn/local.go b/ipn/local.go index 97c45171b..5f2af42e1 100644 --- a/ipn/local.go +++ b/ipn/local.go @@ -26,6 +26,7 @@ import ( "tailscale.com/wgengine" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/router" + "tailscale.com/wgengine/tsdns" ) // LocalBackend is the glue between the major pieces of the Tailscale @@ -311,6 +312,7 @@ func (b *LocalBackend) Start(opts Options) error { b.send(Notify{NetMap: newSt.NetMap}) b.updateFilter(newSt.NetMap) + b.updateDNSMap(newSt.NetMap) if disableDERP { b.e.SetDERPMap(nil) } else { @@ -427,6 +429,27 @@ func (b *LocalBackend) updateFilter(netMap *controlclient.NetworkMap) { b.e.SetFilter(filter.New(netMap.PacketFilter, localNets, b.e.GetFilter(), b.logf)) } +// updateDNSMap updates the domain map in the DNS resolver in wgengine +// based on the given netMap and user preferences. 
+func (b *LocalBackend) updateDNSMap(netMap *controlclient.NetworkMap) { + if netMap == nil { + return + } + dnsMap := &tsdns.Map{DomainToIP: make(map[string]netaddr.IP)} + for _, peer := range netMap.Peers { + if len(peer.Addresses) == 0 { + continue + } + domain := peer.Hostinfo.Hostname + // Like PeerStatus.SimpleHostName() + domain = strings.TrimSuffix(domain, ".local") + domain = strings.TrimSuffix(domain, ".localdomain") + domain = domain + ".ipn.dev" + dnsMap.DomainToIP[domain] = netaddr.IPFrom16(peer.Addresses[0].IP.Addr) + } + b.e.SetDNSMap(dnsMap) +} + // readPoller is a goroutine that receives service lists from // b.portpoll and propagates them into the controlclient's HostInfo. func (b *LocalBackend) readPoller() { @@ -667,6 +690,7 @@ func (b *LocalBackend) SetPrefs(new *Prefs) { } b.updateFilter(b.netMapCache) + b.updateDNSMap(b.netMapCache) if old.WantRunning != new.WantRunning { b.stateMachine() @@ -799,6 +823,13 @@ func routerConfig(cfg *wgcfg.Config, prefs *Prefs, dnsDomains []string) *router. rs.Routes = append(rs.Routes, wgCIDRToNetaddr(peer.AllowedIPs)...) } + // The Tailscale DNS IP. + // TODO(dmytro): make this configurable. 
+ rs.Routes = append(rs.Routes, netaddr.IPPrefix{ + IP: netaddr.IPv4(100, 100, 100, 100), + Bits: 32, + }) + return rs } diff --git a/wgengine/filter/filter.go b/wgengine/filter/filter.go index 2945dbe1b..007960777 100644 --- a/wgengine/filter/filter.go +++ b/wgengine/filter/filter.go @@ -6,7 +6,6 @@ package filter import ( - "fmt" "sync" "time" @@ -137,7 +136,7 @@ func maybeHexdump(flag RunFlags, b []byte) string { var acceptBucket = rate.NewLimiter(rate.Every(10*time.Second), 3) var dropBucket = rate.NewLimiter(rate.Every(5*time.Second), 10) -func (f *Filter) logRateLimit(runflags RunFlags, b []byte, q *packet.ParsedPacket, r Response, why string) { +func (f *Filter) logRateLimit(runflags RunFlags, q *packet.ParsedPacket, r Response, why string) { var verdict string if r == Drop && (runflags&LogDrops) != 0 && dropBucket.Allow() { @@ -151,36 +150,33 @@ func (f *Filter) logRateLimit(runflags RunFlags, b []byte, q *packet.ParsedPacke // Note: it is crucial that q.String() be called only if {accept,drop}Bucket.Allow() passes, // since it causes an allocation. if verdict != "" { - var qs string - if q == nil { - qs = fmt.Sprintf("(%d bytes)", len(b)) - } else { - qs = q.String() - } - f.logf("%s: %s %d %s\n%s", verdict, qs, len(b), why, maybeHexdump(runflags, b)) + b := q.Buffer() + f.logf("%s: %s %d %s\n%s", verdict, q.String(), len(b), why, maybeHexdump(runflags, b)) } } -func (f *Filter) RunIn(b []byte, q *packet.ParsedPacket, rf RunFlags) Response { - r := f.pre(b, q, rf) +// RunIn determines whether this node is allowed to receive q from a Tailscale peer. 
+func (f *Filter) RunIn(q *packet.ParsedPacket, rf RunFlags) Response { + r := f.pre(q, rf) if r == Accept || r == Drop { // already logged return r } r, why := f.runIn(q) - f.logRateLimit(rf, b, q, r, why) + f.logRateLimit(rf, q, r, why) return r } -func (f *Filter) RunOut(b []byte, q *packet.ParsedPacket, rf RunFlags) Response { - r := f.pre(b, q, rf) +// RunOut determines whether this node is allowed to send q to a Tailscale peer. +func (f *Filter) RunOut(q *packet.ParsedPacket, rf RunFlags) Response { + r := f.pre(q, rf) if r == Drop || r == Accept { // already logged return r } r, why := f.runOut(q) - f.logRateLimit(rf, b, q, r, why) + f.logRateLimit(rf, q, r, why) return r } @@ -251,29 +247,28 @@ func (f *Filter) runOut(q *packet.ParsedPacket) (r Response, why string) { return Accept, "ok out" } -func (f *Filter) pre(b []byte, q *packet.ParsedPacket, rf RunFlags) Response { - if len(b) == 0 { +func (f *Filter) pre(q *packet.ParsedPacket, rf RunFlags) Response { + if len(q.Buffer()) == 0 { // wireguard keepalive packet, always permit. return Accept } - if len(b) < 20 { - f.logRateLimit(rf, b, nil, Drop, "too short") + if len(q.Buffer()) < 20 { + f.logRateLimit(rf, q, Drop, "too short") return Drop } - q.Decode(b) switch q.IPProto { case packet.Unknown: // Unknown packets are dangerous; always drop them. - f.logRateLimit(rf, b, q, Drop, "unknown") + f.logRateLimit(rf, q, Drop, "unknown") return Drop case packet.IPv6: - f.logRateLimit(rf, b, q, Drop, "ipv6") + f.logRateLimit(rf, q, Drop, "ipv6") return Drop case packet.Fragment: // Fragments after the first always need to be passed through. // Very small fragments are considered Junk by ParsedPacket. 
- f.logRateLimit(rf, b, q, Accept, "fragment") + f.logRateLimit(rf, q, Accept, "fragment") return Accept } diff --git a/wgengine/filter/filter_test.go b/wgengine/filter/filter_test.go index fc0fef231..783b258f0 100644 --- a/wgengine/filter/filter_test.go +++ b/wgengine/filter/filter_test.go @@ -144,11 +144,12 @@ func TestNoAllocs(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { got := int(testing.AllocsPerRun(1000, func() { - var q ParsedPacket + q := &ParsedPacket{} + q.Decode(test.packet) if test.in { - acl.RunIn(test.packet, &q, 0) + acl.RunIn(q, 0) } else { - acl.RunOut(test.packet, &q, 0) + acl.RunOut(q, 0) } })) @@ -187,12 +188,13 @@ func BenchmarkFilter(b *testing.B) { for _, bench := range benches { b.Run(bench.name, func(b *testing.B) { for i := 0; i < b.N; i++ { - var q ParsedPacket + q := &ParsedPacket{} + q.Decode(bench.packet) // This branch seems to have no measurable impact on performance. if bench.in { - acl.RunIn(bench.packet, &q, 0) + acl.RunIn(q, 0) } else { - acl.RunOut(bench.packet, &q, 0) + acl.RunOut(q, 0) } } }) @@ -215,7 +217,9 @@ func TestPreFilter(t *testing.T) { } f := NewAllowNone(t.Logf) for _, testPacket := range packets { - got := f.pre([]byte(testPacket.b), &ParsedPacket{}, LogDrops|LogAccepts) + p := &ParsedPacket{} + p.Decode(testPacket.b) + got := f.pre(p, LogDrops|LogAccepts) if got != testPacket.want { t.Errorf("%q got=%v want=%v packet:\n%s", testPacket.desc, got, testPacket.want, packet.Hexdump(testPacket.b)) } diff --git a/wgengine/packet/packet.go b/wgengine/packet/packet.go index 92bc1eda8..69f7c99a2 100644 --- a/wgengine/packet/packet.go +++ b/wgengine/packet/packet.go @@ -102,7 +102,7 @@ func ipChecksum(b []byte) uint16 { // It extracts only the subprotocol id, IP addresses, and (if any) ports, // and shouldn't need any memory allocation. 
func (q *ParsedPacket) Decode(b []byte) { - q.b = nil + q.b = b if len(b) < ipHeaderLength { q.IPProto = Unknown @@ -170,7 +170,6 @@ func (q *ParsedPacket) Decode(b []byte) { } q.SrcPort = 0 q.DstPort = 0 - q.b = b q.dataofs = q.subofs + icmpHeaderLength return case TCP: @@ -181,7 +180,6 @@ func (q *ParsedPacket) Decode(b []byte) { q.SrcPort = get16(sub[0:2]) q.DstPort = get16(sub[2:4]) q.TCPFlags = sub[13] & 0x3F - q.b = b headerLength := (sub[12] & 0xF0) >> 2 q.dataofs = q.subofs + int(headerLength) return @@ -192,7 +190,6 @@ func (q *ParsedPacket) Decode(b []byte) { } q.SrcPort = get16(sub[0:2]) q.DstPort = get16(sub[2:4]) - q.b = b q.dataofs = q.subofs + udpHeaderLength return default: @@ -244,6 +241,11 @@ func (q *ParsedPacket) UDPHeader() UDPHeader { } } +// Buffer returns the entire packet buffer. +func (q *ParsedPacket) Buffer() []byte { + return q.b +} + // Sub returns the IP subprotocol section. func (q *ParsedPacket) Sub(begin, n int) []byte { return q.b[q.subofs+begin : q.subofs+begin+n] diff --git a/wgengine/packet/packet_test.go b/wgengine/packet/packet_test.go index 11f75e1e3..7bb88b213 100644 --- a/wgengine/packet/packet_test.go +++ b/wgengine/packet/packet_test.go @@ -90,6 +90,7 @@ var ipv6PacketBuffer = []byte{ } var ipv6PacketDecode = ParsedPacket{ + b: ipv6PacketBuffer, IPProto: IPv6, } @@ -100,6 +101,7 @@ var unknownPacketBuffer = []byte{ } var unknownPacketDecode = ParsedPacket{ + b: unknownPacketBuffer, IPProto: Unknown, } diff --git a/wgengine/tsdns/tsdns.go b/wgengine/tsdns/tsdns.go new file mode 100644 index 000000000..8c0efdf56 --- /dev/null +++ b/wgengine/tsdns/tsdns.go @@ -0,0 +1,274 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package tsdns provides a Resolver struct capable of resolving +// domains on a Tailscale network. 
+package tsdns + +import ( + "encoding/binary" + "errors" + "strings" + "sync" + + dns "golang.org/x/net/dns/dnsmessage" + "inet.af/netaddr" + "tailscale.com/types/logger" + "tailscale.com/wgengine/packet" +) + +// defaultTTL is the TTL in seconds of all responses from Resolver. +const defaultTTL = 600 + +var ( + errMapNotSet = errors.New("domain map not set") + errNoSuchDomain = errors.New("domain does not exist") + errNotImplemented = errors.New("query type not implemented") + errNotOurName = errors.New("not an *.ipn.dev domain") + errNotQuery = errors.New("not a DNS query") +) + +var ( + defaultIP = packet.IP(binary.BigEndian.Uint32([]byte{100, 100, 100, 100})) + defaultPort = uint16(53) +) + +// Map is all the data Resolver needs to resolve DNS queries. +type Map struct { + // DomainToIP is a mapping of Tailscale domains to their IP addresses. + // For example, monitoring.ipn.dev -> 100.64.0.1. + DomainToIP map[string]netaddr.IP +} + +// Resolver is a DNS resolver for domain names of the form *.ipn.dev +// It is intended to answer queries for which AcceptsPacket returns true. +type Resolver struct { + logf logger.Logf + + // ip is the IP on which the resolver is listening. + ip packet.IP + // port is the port on which the resolver is listening. + port uint16 + + // mu guards the following fields from being updated while used. + mu sync.Mutex + // dnsMap is the map most recently received from the control server. + dnsMap *Map +} + +// NewResolver constructs a resolver with default parameters. +func NewResolver(logf logger.Logf) *Resolver { + r := &Resolver{ + logf: logf, + ip: defaultIP, + port: defaultPort, + } + + return r +} + +// AcceptsPacket determines if the given packet is +// directed to this resolver (by ip and port). +// We also require that UDP be used to simplify things for now. 
+func (r *Resolver) AcceptsPacket(in *packet.ParsedPacket) bool { + return in.DstIP == r.ip && in.DstPort == r.port && in.IPProto == packet.UDP +} + +// SetMap sets the resolver's DNS map. 
+func (r *Resolver) SetMap(m *Map) { + r.mu.Lock() + r.dnsMap = m + r.mu.Unlock() +} + +// Resolve maps a given domain name to the IP address of the host that owns it. +func (r *Resolver) Resolve(domain string) (netaddr.IP, dns.RCode, error) { + // If not a subdomain of ipn.dev, then we must refuse this query. + // We do this before checking the map to distinguish between nonexistent domains + // and misdirected queries. + if !strings.HasSuffix(domain, ".ipn.dev") { + return netaddr.IP{}, dns.RCodeRefused, errNotOurName + } + + r.mu.Lock() + if r.dnsMap == nil { + r.mu.Unlock() + return netaddr.IP{}, dns.RCodeServerFailure, errMapNotSet + } + addr, found := r.dnsMap.DomainToIP[domain] + r.mu.Unlock() + + if !found { + return netaddr.IP{}, dns.RCodeNameError, errNoSuchDomain + } + return addr, dns.RCodeSuccess, nil +} + +type response struct { + Header dns.Header + ResourceHeader dns.ResourceHeader + Question dns.Question + // TODO(dmytro): support IPv6. + IP netaddr.IP +} + +// parseQuery parses the query in given packet into a response struct. +func (r *Resolver) parseQuery(query *packet.ParsedPacket, resp *response) error { + var parser dns.Parser + var err error + + resp.Header, err = parser.Start(query.Payload()) + if err != nil { + return err + } + + if resp.Header.Response { + return errNotQuery + } + + resp.Question, err = parser.Question() + if err != nil { + return err + } + + return nil +} + +// makeResponse resolves the question stored in resp and sets the answer fields. +func (r *Resolver) makeResponse(resp *response) error { + var err error + + name := resp.Question.Name.String() + if len(name) > 0 { + name = name[:len(name)-1] + } + + if resp.Question.Type == dns.TypeA { + // Remove final dot from name: *.ipn.dev. 
-> *.ipn.dev + resp.IP, resp.Header.RCode, err = r.Resolve(name) + } else { + resp.Header.RCode = dns.RCodeNotImplemented + err = errNotImplemented + } + + return err +} + +// marshalAnswer serializes the answer record into an active builder. 
+func marshalAnswer(resp *response, builder *dns.Builder) error { + var answer dns.AResource + + err := builder.StartAnswers() + if err != nil { + return err + } + + answerHeader := dns.ResourceHeader{ + Name: resp.Question.Name, + Type: dns.TypeA, + Class: dns.ClassINET, + TTL: defaultTTL, + } + ip := resp.IP.As16() + copy(answer.A[:], ip[12:]) + return builder.AResource(answerHeader, answer) +} + +// marshalResponse serializes the DNS response into an active builder. +func marshalResponse(resp *response, builder *dns.Builder) ([]byte, error) { + resp.Header.Response = true + resp.Header.Authoritative = true + if resp.Header.RecursionDesired { + resp.Header.RecursionAvailable = true + } + + err := builder.StartQuestions() + if err != nil { + return nil, err + } + + err = builder.Question(resp.Question) + if err != nil { + return nil, err + } + + if resp.Header.RCode == dns.RCodeSuccess { + err = marshalAnswer(resp, builder) + if err != nil { + return nil, err + } + } + + return builder.Finish() +} + +func marshalResponsePacket(query *packet.ParsedPacket, resp *response, buf []byte) ([]byte, error) { + udpHeader := query.UDPHeader() + udpHeader.ToResponse() + offset := udpHeader.Len() + + // dns.Builder appends to the passed buffer (without reallocation when possible), + // so we pass in a zero-length slice starting at the point it should start writing. + builder := dns.NewBuilder(buf[offset:offset], resp.Header) + + // rbuf is the response slice with the correct length starting at offset. + rbuf, err := marshalResponse(resp, &builder) + if err != nil { + return nil, err + } + + end := offset + len(rbuf) + err = udpHeader.Marshal(buf[:end]) + if err != nil { + return nil, err + } + + return buf[:end], nil +} + +// Respond writes a response to query into buf and returns buf trimmed to the response length. +// It is assumed that r.AcceptsPacket(query) is true. 
+func (r *Resolver) Respond(query *packet.ParsedPacket, buf []byte) ([]byte, error) { + var resp response + var err error + + // 0. Verify that contract is upheld. + if !r.AcceptsPacket(query) { + r.logf("[unexpected] tsdns: Respond called on query not for this resolver") + resp.Header.RCode = dns.RCodeServerFailure + return marshalResponsePacket(query, &resp, buf) + } + // A DNS response is at least as long as the query + if len(buf) < len(query.Buffer()) { + r.logf("[unexpected] tsdns: response buffer is too small") + resp.Header.RCode = dns.RCodeServerFailure + return marshalResponsePacket(query, &resp, buf) + } + + // 1. Parse query packet. + err = r.parseQuery(query, &resp) + // We will not return this error: it is the sender's fault. + if err != nil { + r.logf("tsdns: error during query parsing: %v", err) + resp.Header.RCode = dns.RCodeFormatError + return marshalResponsePacket(query, &resp, buf) + } + + // 2. Service the query. + err = r.makeResponse(&resp) + // We will not return this error: it is the sender's fault. + if err != nil { + r.logf("tsdns: error during name resolution: %v", err) + return marshalResponsePacket(query, &resp, buf) + } + // For now, we require IPv4 in all cases. + // If we somehow came up with a non-IPv4 address, it's our fault. + if !resp.IP.Is4() { + resp.Header.RCode = dns.RCodeServerFailure + r.logf("tsdns: error during name resolution: IPv6 address: %v", resp.IP) + } + + // 3. Serialize the response. 
+ return marshalResponsePacket(query, &resp, buf) +} diff --git a/wgengine/tstun/tun.go b/wgengine/tstun/tun.go index 90457d96c..41393b7c3 100644 --- a/wgengine/tstun/tun.go +++ b/wgengine/tstun/tun.go @@ -10,6 +10,7 @@ import ( "errors" "io" "os" + "sync" "sync/atomic" "github.com/tailscale/wireguard-go/device" @@ -19,10 +20,12 @@ import ( "tailscale.com/wgengine/packet" ) -const ( - readMaxSize = device.MaxMessageSize - readOffset = device.MessageTransportHeaderSize -) +const maxBufferSize = device.MaxMessageSize + +// PacketStartOffset is the minimal amount of leading space that must exist +// before &packet[offset] in a packet passed to Read, Write, or InjectInboundDirect. +// This is necessary to avoid reallocation in wireguard-go internals. +const PacketStartOffset = device.MessageTransportHeaderSize // MaxPacketSize is the maximum size (in bytes) // of a packet that can be injected into a tstun.TUN. @@ -35,7 +38,15 @@ var ( ErrFiltered = errors.New("packet dropped by filter") ) -var errPacketTooBig = errors.New("packet too big") +var ( + errPacketTooBig = errors.New("packet too big") + errOffsetTooBig = errors.New("offset larger than buffer length") + errOffsetTooSmall = errors.New("offset smaller than PacketStartOffset") +) + +// FilterFunc is a packet-filtering function with access to the TUN device. +// It must not hold onto the packet struct, as its backing storage will be reused. +type FilterFunc func(*packet.ParsedPacket, *TUN) filter.Response // TUN wraps a tun.Device from wireguard-go, // augmenting it with filtering and packet injection. @@ -47,10 +58,14 @@ type TUN struct { tdev tun.Device // buffer stores the oldest unconsumed packet from tdev. - // It is made a static buffer in order to avoid graticious allocation. - buffer [readMaxSize]byte + // It is made a static buffer in order to avoid allocations. + buffer [maxBufferSize]byte // bufferConsumed synchronizes access to buffer (shared by Read and poll). 
bufferConsumed chan struct{} + // parsedPacketPool holds a pool of ParsedPacket structs for use in filtering. + // This is needed because escape analysis cannot see that parsed packets + // do not escape through {Pre,Post}Filter{In,Out}. + parsedPacketPool sync.Pool // of *packet.ParsedPacket // closed signals poll (by closing) when the device is closed. closed chan struct{} @@ -73,8 +88,19 @@ type TUN struct { // filterFlags control the verbosity of logging packet drops/accepts. filterFlags filter.RunFlags - // insecure disables all filtering when set. This is useful in tests. - insecure bool + // PreFilterIn is the inbound filter function that runs before the main filter + // and therefore sees the packets that may be later dropped by it. + PreFilterIn FilterFunc + // PostFilterIn is the inbound filter function that runs after the main filter. + PostFilterIn FilterFunc + // PreFilterOut is the outbound filter function that runs before the main filter + // and therefore sees the packets that may be later dropped by it. + PreFilterOut FilterFunc + // PostFilterOut is the outbound filter function that runs after the main filter. + PostFilterOut FilterFunc + + // disableFilter disables all filtering when set. This should only be used in tests. + disableFilter bool } func WrapTUN(logf logger.Logf, tdev tun.Device) *TUN { @@ -87,8 +113,14 @@ func WrapTUN(logf logger.Logf, tdev tun.Device) *TUN { closed: make(chan struct{}), errors: make(chan error), outbound: make(chan []byte), - filterFlags: filter.LogAccepts | filter.LogDrops, + // TODO(dmytro): (highly rate-limited) hexdumps should happen on unknown packets. + filterFlags: filter.LogAccepts | filter.LogDrops, } + + tun.parsedPacketPool.New = func() interface{} { + return new(packet.ParsedPacket) + } + go tun.poll() // The buffer starts out consumed. tun.bufferConsumed <- struct{}{} @@ -140,10 +172,10 @@ func (t *TUN) poll() { // continue } - // Read may use memory in t.buffer before readOffset for mandatory headers. 
+ // Read may use memory in t.buffer before PacketStartOffset for mandatory headers. // This is the rationale behind the tun.TUN.{Read,Write} interfaces // and the reason t.buffer has size MaxMessageSize and not MaxContentSize. - n, err := t.tdev.Read(t.buffer[:], readOffset) + n, err := t.tdev.Read(t.buffer[:], PacketStartOffset) if err != nil { select { case <-t.closed: @@ -165,26 +197,41 @@ func (t *TUN) poll() { select { case <-t.closed: return - case t.outbound <- t.buffer[readOffset : readOffset+n]: + case t.outbound <- t.buffer[PacketStartOffset : PacketStartOffset+n]: // continue } } } func (t *TUN) filterOut(buf []byte) filter.Response { + p := t.parsedPacketPool.Get().(*packet.ParsedPacket) + defer t.parsedPacketPool.Put(p) + p.Decode(buf) + + if t.PreFilterOut != nil { + if t.PreFilterOut(p, t) == filter.Drop { + return filter.Drop + } + } + filt, _ := t.filter.Load().(*filter.Filter) if filt == nil { - t.logf("Warning: you forgot to use SetFilter()! Packet dropped.") + t.logf("tstun: warning: you forgot to use SetFilter()! Packet dropped.") return filter.Drop } - var p packet.ParsedPacket - if filt.RunOut(buf, &p, t.filterFlags) == filter.Accept { - return filter.Accept + if filt.RunOut(p, t.filterFlags) != filter.Accept { + return filter.Drop } - return filter.Drop + if t.PostFilterOut != nil { + if t.PostFilterOut(p, t) == filter.Drop { + return filter.Drop + } + } + + return filter.Accept } func (t *TUN) Read(buf []byte, offset int) (int, error) { @@ -200,12 +247,16 @@ func (t *TUN) Read(buf []byte, offset int) (int, error) { // t.buffer has a fixed location in memory, // so this is the easiest way to tell when it has been consumed. // &packet[0] can be used because empty packets do not reach t.outbound. - if &packet[0] == &t.buffer[readOffset] { + if &packet[0] == &t.buffer[PacketStartOffset] { t.bufferConsumed <- struct{}{} + } else { + // If the packet is not from t.buffer, then it is an injected packet. 
+ // In this case, we return early to bypass filtering. + return n, nil } } - if !t.insecure { + if !t.disableFilter { response := t.filterOut(buf[offset : offset+n]) if response != filter.Accept { // Wireguard considers read errors fatal; pretend nothing was read @@ -217,35 +268,38 @@ func (t *TUN) Read(buf []byte, offset int) (int, error) { } func (t *TUN) filterIn(buf []byte) filter.Response { + p := t.parsedPacketPool.Get().(*packet.ParsedPacket) + defer t.parsedPacketPool.Put(p) + p.Decode(buf) + + if t.PreFilterIn != nil { + if t.PreFilterIn(p, t) == filter.Drop { + return filter.Drop + } + } + filt, _ := t.filter.Load().(*filter.Filter) if filt == nil { - t.logf("Warning: you forgot to use SetFilter()! Packet dropped.") + t.logf("tstun: warning: you forgot to use SetFilter()! Packet dropped.") return filter.Drop } - var p packet.ParsedPacket - if filt.RunIn(buf, &p, t.filterFlags) == filter.Accept { - // Only in fake mode, answer any incoming pings. - if p.IsEchoRequest() { - ft, ok := t.tdev.(*fakeTUN) - if ok { - header := p.ICMPHeader() - header.ToResponse() - packet := packet.Generate(&header, p.Payload()) - ft.Write(packet, 0) - // We already handled it, stop. - return filter.Drop - } + if filt.RunIn(p, t.filterFlags) != filter.Accept { + return filter.Drop + } + + if t.PostFilterIn != nil { + if t.PostFilterIn(p, t) == filter.Drop { + return filter.Drop } - return filter.Accept } - return filter.Drop + return filter.Accept } func (t *TUN) Write(buf []byte, offset int) (int, error) { - if !t.insecure { + if !t.disableFilter { response := t.filterIn(buf[offset:]) if response != filter.Accept { return 0, ErrFiltered @@ -264,24 +318,53 @@ func (t *TUN) SetFilter(filt *filter.Filter) { t.filter.Store(filt) } -// InjectInbound makes the TUN device behave as if a packet +// InjectInboundDirect makes the TUN device behave as if a packet // with the given contents was received from the network. // It blocks and does not take ownership of the packet. 
-// Injecting an empty packet is a no-op. -func (t *TUN) InjectInbound(packet []byte) error { +// The injected packet will not pass through inbound filters. +// +// The packet contents are to start at &buf[offset]. +// offset must be greater than or equal to PacketStartOffset. +// The space before &buf[offset] will be used by Wireguard. +func (t *TUN) InjectInboundDirect(buf []byte, offset int) error { + if len(buf) > MaxPacketSize { + return errPacketTooBig + } + if len(buf) < offset { + return errOffsetTooBig + } + if offset < PacketStartOffset { + return errOffsetTooSmall + } + + // Write to the underlying device to skip filters. + _, err := t.tdev.Write(buf, offset) + return err +} + +// InjectInboundCopy takes a packet without leading space, +// reallocates it to conform to the InjectInboundDirect interface +// and calls InjectInboundDirect on it. Injecting a nil packet is a no-op. +func (t *TUN) InjectInboundCopy(packet []byte) error { + // We duplicate this check from InjectInboundDirect here + // to avoid wasting an allocation on an oversized packet. if len(packet) > MaxPacketSize { return errPacketTooBig } if len(packet) == 0 { return nil } - _, err := t.Write(packet, 0) - return err + + buf := make([]byte, PacketStartOffset+len(packet)) + copy(buf[PacketStartOffset:], packet) + + return t.InjectInboundDirect(buf, PacketStartOffset) } // InjectOutbound makes the TUN device behave as if a packet // with the given contents was sent to the network. // It does not block, but takes ownership of the packet. +// The injected packet will not pass through outbound filters. // Injecting an empty packet is a no-op. 
func (t *TUN) InjectOutbound(packet []byte) error { if len(packet) > MaxPacketSize { diff --git a/wgengine/tstun/tun_test.go b/wgengine/tstun/tun_test.go index 258ecc35d..c4c065fd0 100644 --- a/wgengine/tstun/tun_test.go +++ b/wgengine/tstun/tun_test.go @@ -58,7 +58,7 @@ func newChannelTUN(logf logger.Logf, secure bool) (*tuntest.ChannelTUN, *TUN) { if secure { setfilter(logf, tun) } else { - tun.insecure = true + tun.disableFilter = true } return chtun, tun } @@ -69,7 +69,7 @@ func newFakeTUN(logf logger.Logf, secure bool) (*fakeTUN, *TUN) { if secure { setfilter(logf, tun) } else { - tun.insecure = true + tun.disableFilter = true } return ftun.(*fakeTUN), tun } @@ -151,7 +151,7 @@ func TestWriteAndInject(t *testing.T) { for _, packet := range injected { go func(packet string) { payload := []byte(packet) - err := tun.InjectInbound(payload) + err := tun.InjectInboundCopy(payload) if err != nil { t.Errorf("%s: error: %v", packet, err) } diff --git a/wgengine/userspace.go b/wgengine/userspace.go index 6e9bbd52b..ee69770a5 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -34,6 +34,7 @@ import ( "tailscale.com/wgengine/monitor" "tailscale.com/wgengine/packet" "tailscale.com/wgengine/router" + "tailscale.com/wgengine/tsdns" "tailscale.com/wgengine/tstun" ) @@ -54,6 +55,7 @@ type userspaceEngine struct { tundev *tstun.TUN wgdev *device.Device router router.Router + resolver *tsdns.Resolver magicConn *magicsock.Conn linkMon *monitor.Mon @@ -73,6 +75,28 @@ type userspaceEngine struct { // Lock ordering: wgLock, then mu. } +// RouterGen is the signature for a function that creates a +// router.Router. +type RouterGen func(logf logger.Logf, wgdev *device.Device, tundev tun.Device) (router.Router, error) + +type EngineConfig struct { + // Logf is the logging function used by the engine. + Logf logger.Logf + // TUN is the tun device used by the engine. + TUN tun.Device + // RouterGen is the function used to instantiate the router. 
+ RouterGen RouterGen + // ListenPort is the port on which the engine will listen. + ListenPort uint16 + // EchoRespondToAll determines whether ICMP Echo requests incoming from Tailscale peers + // will be intercepted and responded to, regardless of the source host. + EchoRespondToAll bool + // UseTailscaleDNS determines whether DNS requests for names of the form *.ipn.dev + // directed to the designated Tailscale DNS address (see wgengine/tsdns) + // will be intercepted and resolved by a tsdns.Resolver. + UseTailscaleDNS bool +} + type Loggify struct { f logger.Logf } @@ -84,8 +108,14 @@ func (l *Loggify) Write(b []byte) (int, error) { func NewFakeUserspaceEngine(logf logger.Logf, listenPort uint16) (Engine, error) { logf("Starting userspace wireguard engine (FAKE tuntap device).") - tundev := tstun.WrapTUN(logf, tstun.NewFakeTUN()) - return NewUserspaceEngineAdvanced(logf, tundev, router.NewFake, listenPort) + conf := EngineConfig{ + Logf: logf, + TUN: tstun.NewFakeTUN(), + RouterGen: router.NewFake, + ListenPort: listenPort, + EchoRespondToAll: true, + } + return NewUserspaceEngineAdvanced(conf) } // NewUserspaceEngine creates the named tun device and returns a @@ -104,38 +134,53 @@ func NewUserspaceEngine(logf logger.Logf, tunname string, listenPort uint16) (En return nil, err } logf("CreateTUN ok.") - tundev := tstun.WrapTUN(logf, tun) - e, err := NewUserspaceEngineAdvanced(logf, tundev, router.New, listenPort) + conf := EngineConfig{ + Logf: logf, + TUN: tun, + RouterGen: router.New, + ListenPort: listenPort, + // TODO(dmytro): plumb this down. + UseTailscaleDNS: true, + } + + e, err := NewUserspaceEngineAdvanced(conf) if err != nil { return nil, err } return e, err } -// RouterGen is the signature for a function that creates a -// router.Router. 
-type RouterGen func(logf logger.Logf, wgdev *device.Device, tundev tun.Device) (router.Router, error) - -// NewUserspaceEngineAdvanced is like NewUserspaceEngine but takes a pre-created TUN device and allows specifing -// a custom router constructor and listening port. -func NewUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen RouterGen, listenPort uint16) (Engine, error) { - return newUserspaceEngineAdvanced(logf, tundev, routerGen, listenPort) +// NewUserspaceEngineAdvanced is like NewUserspaceEngine +// but provides control over all config fields. +func NewUserspaceEngineAdvanced(conf EngineConfig) (Engine, error) { + return newUserspaceEngineAdvanced(conf) } -func newUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen RouterGen, listenPort uint16) (_ Engine, reterr error) { +func newUserspaceEngineAdvanced(conf EngineConfig) (_ Engine, reterr error) { + logf := conf.Logf + e := &userspaceEngine{ - logf: logf, - reqCh: make(chan struct{}, 1), - waitCh: make(chan struct{}), - tundev: tundev, - pingers: make(map[wgcfg.Key]*pinger), + logf: logf, + reqCh: make(chan struct{}, 1), + waitCh: make(chan struct{}), + tundev: tstun.WrapTUN(logf, conf.TUN), + resolver: tsdns.NewResolver(logf), + pingers: make(map[wgcfg.Key]*pinger), } e.linkState, _ = getLinkState() + // Respond to all pings only in fake mode. 
+ if conf.EchoRespondToAll { + e.tundev.PostFilterIn = echoRespondToAll + } + if conf.UseTailscaleDNS { + e.tundev.PreFilterOut = e.handleDNS + } + mon, err := monitor.New(logf, func() { e.LinkChange(false) }) if err != nil { - tundev.Close() + e.tundev.Close() return nil, err } e.linkMon = mon @@ -149,12 +194,12 @@ func newUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen R } magicsockOpts := magicsock.Options{ Logf: logf, - Port: listenPort, + Port: conf.ListenPort, EndpointsFunc: endpointsFn, } e.magicConn, err = magicsock.NewConn(magicsockOpts) if err != nil { - tundev.Close() + e.tundev.Close() return nil, fmt.Errorf("wgengine: %v", err) } @@ -211,7 +256,7 @@ func newUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen R // Pass the underlying tun.(*NativeDevice) to the router: // routers do not Read or Write, but do access native interfaces. - e.router, err = routerGen(logf, e.wgdev, e.tundev.Unwrap()) + e.router, err = conf.RouterGen(logf, e.wgdev, e.tundev.Unwrap()) if err != nil { e.magicConn.Close() return nil, err @@ -256,6 +301,37 @@ func newUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen R return e, nil } +// echoRespondToAll is an inbound post-filter responding to all echo requests. +func echoRespondToAll(p *packet.ParsedPacket, t *tstun.TUN) filter.Response { + if p.IsEchoRequest() { + header := p.ICMPHeader() + header.ToResponse() + packet := packet.Generate(&header, p.Payload()) + t.InjectOutbound(packet) + // We already handled it, stop. + return filter.Drop + } + return filter.Accept +} + +// handleDNS is an outbound pre-filter resolving Tailscale domains. +func (e *userspaceEngine) handleDNS(p *packet.ParsedPacket, t *tstun.TUN) filter.Response { + if e.resolver.AcceptsPacket(p) { + // TODO(dmytro): avoid this allocation without having tsdns know tstun quirks. 
+ buf := make([]byte, tstun.MaxPacketSize) + offset := tstun.PacketStartOffset + response, err := e.resolver.Respond(p, buf[offset:]) + if err != nil { + e.logf("DNS resolver error: %v", err) + } else { + t.InjectInboundDirect(buf[:offset+len(response)], offset) + } + // We already handled it, stop. + return filter.Drop + } + return filter.Accept +} + // pinger sends ping packets for a few seconds. // // These generated packets are used to ensure we trigger the spray logic in @@ -447,6 +523,10 @@ func (e *userspaceEngine) SetFilter(filt *filter.Filter) { e.tundev.SetFilter(filt) } +func (e *userspaceEngine) SetDNSMap(dm *tsdns.Map) { + e.resolver.SetMap(dm) +} + func (e *userspaceEngine) SetStatusCallback(cb StatusCallback) { e.mu.Lock() defer e.mu.Unlock() diff --git a/wgengine/watchdog.go b/wgengine/watchdog.go index ef9393a47..70568e504 100644 --- a/wgengine/watchdog.go +++ b/wgengine/watchdog.go @@ -15,6 +15,7 @@ import ( "tailscale.com/tailcfg" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/router" + "tailscale.com/wgengine/tsdns" ) // NewWatchdog wraps an Engine and makes sure that all methods complete @@ -74,6 +75,9 @@ func (e *watchdogEngine) GetFilter() *filter.Filter { func (e *watchdogEngine) SetFilter(filt *filter.Filter) { e.watchdog("SetFilter", func() { e.wrap.SetFilter(filt) }) } +func (e *watchdogEngine) SetDNSMap(dm *tsdns.Map) { + e.watchdog("SetDNSMap", func() { e.wrap.SetDNSMap(dm) }) +} func (e *watchdogEngine) SetStatusCallback(cb StatusCallback) { e.watchdog("SetStatusCallback", func() { e.wrap.SetStatusCallback(cb) }) } diff --git a/wgengine/watchdog_test.go b/wgengine/watchdog_test.go index 0e45dd641..4c86d8ebe 100644 --- a/wgengine/watchdog_test.go +++ b/wgengine/watchdog_test.go @@ -10,9 +10,6 @@ import ( "strings" "testing" "time" - - "tailscale.com/wgengine/router" - "tailscale.com/wgengine/tstun" ) func TestWatchdog(t *testing.T) { @@ -20,8 +17,7 @@ func TestWatchdog(t *testing.T) { t.Run("default watchdog does not fire", 
func(t *testing.T) { t.Parallel() - tun := tstun.WrapTUN(t.Logf, tstun.NewFakeTUN()) - e, err := NewUserspaceEngineAdvanced(t.Logf, tun, router.NewFake, 0) + e, err := NewFakeUserspaceEngine(t.Logf, 0) if err != nil { t.Fatal(err) } @@ -37,8 +33,7 @@ func TestWatchdog(t *testing.T) { t.Run("watchdog fires on blocked getStatus", func(t *testing.T) { t.Parallel() - tun := tstun.WrapTUN(t.Logf, tstun.NewFakeTUN()) - e, err := NewUserspaceEngineAdvanced(t.Logf, tun, router.NewFake, 0) + e, err := NewFakeUserspaceEngine(t.Logf, 0) if err != nil { t.Fatal(err) } diff --git a/wgengine/wgengine.go b/wgengine/wgengine.go index 81dcee80e..b583af2da 100644 --- a/wgengine/wgengine.go +++ b/wgengine/wgengine.go @@ -13,6 +13,7 @@ import ( "tailscale.com/tailcfg" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/router" + "tailscale.com/wgengine/tsdns" ) // ByteCount is the number of bytes that have been sent or received. @@ -65,6 +66,9 @@ type Engine interface { // SetFilter updates the packet filter. SetFilter(*filter.Filter) + // SetDNSMap updates the DNS map. + SetDNSMap(*tsdns.Map) + // SetStatusCallback sets the function to call when the // WireGuard status changes. SetStatusCallback(StatusCallback)