tsdns: initial implementation of a Tailscale DNS resolver (#396)

Signed-off-by: Dmytro Shynkevych <dmytro@tailscale.com>
reviewable/pr450/r1
Dmytro Shynkevych 5 years ago committed by GitHub
parent 5e1ee4be53
commit 511840b1f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -26,6 +26,7 @@ import (
"tailscale.com/wgengine" "tailscale.com/wgengine"
"tailscale.com/wgengine/filter" "tailscale.com/wgengine/filter"
"tailscale.com/wgengine/router" "tailscale.com/wgengine/router"
"tailscale.com/wgengine/tsdns"
) )
// LocalBackend is the glue between the major pieces of the Tailscale // LocalBackend is the glue between the major pieces of the Tailscale
@ -311,6 +312,7 @@ func (b *LocalBackend) Start(opts Options) error {
b.send(Notify{NetMap: newSt.NetMap}) b.send(Notify{NetMap: newSt.NetMap})
b.updateFilter(newSt.NetMap) b.updateFilter(newSt.NetMap)
b.updateDNSMap(newSt.NetMap)
if disableDERP { if disableDERP {
b.e.SetDERPMap(nil) b.e.SetDERPMap(nil)
} else { } else {
@ -427,6 +429,27 @@ func (b *LocalBackend) updateFilter(netMap *controlclient.NetworkMap) {
b.e.SetFilter(filter.New(netMap.PacketFilter, localNets, b.e.GetFilter(), b.logf)) b.e.SetFilter(filter.New(netMap.PacketFilter, localNets, b.e.GetFilter(), b.logf))
} }
// updateDNSMap updates the domain map in the DNS resolver in wgengine
// based on the given netMap and user preferences.
func (b *LocalBackend) updateDNSMap(netMap *controlclient.NetworkMap) {
if netMap == nil {
return
}
dnsMap := &tsdns.Map{DomainToIP: make(map[string]netaddr.IP)}
for _, peer := range netMap.Peers {
if len(peer.Addresses) == 0 {
continue
}
domain := peer.Hostinfo.Hostname
// Like PeerStatus.SimpleHostName()
domain = strings.TrimSuffix(domain, ".local")
domain = strings.TrimSuffix(domain, ".localdomain")
domain = domain + ".ipn.dev"
dnsMap.DomainToIP[domain] = netaddr.IPFrom16(peer.Addresses[0].IP.Addr)
}
b.e.SetDNSMap(dnsMap)
}
// readPoller is a goroutine that receives service lists from // readPoller is a goroutine that receives service lists from
// b.portpoll and propagates them into the controlclient's HostInfo. // b.portpoll and propagates them into the controlclient's HostInfo.
func (b *LocalBackend) readPoller() { func (b *LocalBackend) readPoller() {
@ -667,6 +690,7 @@ func (b *LocalBackend) SetPrefs(new *Prefs) {
} }
b.updateFilter(b.netMapCache) b.updateFilter(b.netMapCache)
b.updateDNSMap(b.netMapCache)
if old.WantRunning != new.WantRunning { if old.WantRunning != new.WantRunning {
b.stateMachine() b.stateMachine()
@ -799,6 +823,13 @@ func routerConfig(cfg *wgcfg.Config, prefs *Prefs, dnsDomains []string) *router.
rs.Routes = append(rs.Routes, wgCIDRToNetaddr(peer.AllowedIPs)...) rs.Routes = append(rs.Routes, wgCIDRToNetaddr(peer.AllowedIPs)...)
} }
// The Tailscale DNS IP.
// TODO(dmytro): make this configurable.
rs.Routes = append(rs.Routes, netaddr.IPPrefix{
IP: netaddr.IPv4(100, 100, 100, 100),
Bits: 32,
})
return rs return rs
} }

@ -6,7 +6,6 @@
package filter package filter
import ( import (
"fmt"
"sync" "sync"
"time" "time"
@ -137,7 +136,7 @@ func maybeHexdump(flag RunFlags, b []byte) string {
var acceptBucket = rate.NewLimiter(rate.Every(10*time.Second), 3) var acceptBucket = rate.NewLimiter(rate.Every(10*time.Second), 3)
var dropBucket = rate.NewLimiter(rate.Every(5*time.Second), 10) var dropBucket = rate.NewLimiter(rate.Every(5*time.Second), 10)
func (f *Filter) logRateLimit(runflags RunFlags, b []byte, q *packet.ParsedPacket, r Response, why string) { func (f *Filter) logRateLimit(runflags RunFlags, q *packet.ParsedPacket, r Response, why string) {
var verdict string var verdict string
if r == Drop && (runflags&LogDrops) != 0 && dropBucket.Allow() { if r == Drop && (runflags&LogDrops) != 0 && dropBucket.Allow() {
@ -151,36 +150,33 @@ func (f *Filter) logRateLimit(runflags RunFlags, b []byte, q *packet.ParsedPacke
// Note: it is crucial that q.String() be called only if {accept,drop}Bucket.Allow() passes, // Note: it is crucial that q.String() be called only if {accept,drop}Bucket.Allow() passes,
// since it causes an allocation. // since it causes an allocation.
if verdict != "" { if verdict != "" {
var qs string b := q.Buffer()
if q == nil { f.logf("%s: %s %d %s\n%s", verdict, q.String(), len(b), why, maybeHexdump(runflags, b))
qs = fmt.Sprintf("(%d bytes)", len(b))
} else {
qs = q.String()
}
f.logf("%s: %s %d %s\n%s", verdict, qs, len(b), why, maybeHexdump(runflags, b))
} }
} }
func (f *Filter) RunIn(b []byte, q *packet.ParsedPacket, rf RunFlags) Response { // RunIn determines whether this node is allowed to receive q from a Tailscale peer.
r := f.pre(b, q, rf) func (f *Filter) RunIn(q *packet.ParsedPacket, rf RunFlags) Response {
r := f.pre(q, rf)
if r == Accept || r == Drop { if r == Accept || r == Drop {
// already logged // already logged
return r return r
} }
r, why := f.runIn(q) r, why := f.runIn(q)
f.logRateLimit(rf, b, q, r, why) f.logRateLimit(rf, q, r, why)
return r return r
} }
func (f *Filter) RunOut(b []byte, q *packet.ParsedPacket, rf RunFlags) Response { // RunOut determines whether this node is allowed to send q to a Tailscale peer.
r := f.pre(b, q, rf) func (f *Filter) RunOut(q *packet.ParsedPacket, rf RunFlags) Response {
r := f.pre(q, rf)
if r == Drop || r == Accept { if r == Drop || r == Accept {
// already logged // already logged
return r return r
} }
r, why := f.runOut(q) r, why := f.runOut(q)
f.logRateLimit(rf, b, q, r, why) f.logRateLimit(rf, q, r, why)
return r return r
} }
@ -251,29 +247,28 @@ func (f *Filter) runOut(q *packet.ParsedPacket) (r Response, why string) {
return Accept, "ok out" return Accept, "ok out"
} }
func (f *Filter) pre(b []byte, q *packet.ParsedPacket, rf RunFlags) Response { func (f *Filter) pre(q *packet.ParsedPacket, rf RunFlags) Response {
if len(b) == 0 { if len(q.Buffer()) == 0 {
// wireguard keepalive packet, always permit. // wireguard keepalive packet, always permit.
return Accept return Accept
} }
if len(b) < 20 { if len(q.Buffer()) < 20 {
f.logRateLimit(rf, b, nil, Drop, "too short") f.logRateLimit(rf, q, Drop, "too short")
return Drop return Drop
} }
q.Decode(b)
switch q.IPProto { switch q.IPProto {
case packet.Unknown: case packet.Unknown:
// Unknown packets are dangerous; always drop them. // Unknown packets are dangerous; always drop them.
f.logRateLimit(rf, b, q, Drop, "unknown") f.logRateLimit(rf, q, Drop, "unknown")
return Drop return Drop
case packet.IPv6: case packet.IPv6:
f.logRateLimit(rf, b, q, Drop, "ipv6") f.logRateLimit(rf, q, Drop, "ipv6")
return Drop return Drop
case packet.Fragment: case packet.Fragment:
// Fragments after the first always need to be passed through. // Fragments after the first always need to be passed through.
// Very small fragments are considered Junk by ParsedPacket. // Very small fragments are considered Junk by ParsedPacket.
f.logRateLimit(rf, b, q, Accept, "fragment") f.logRateLimit(rf, q, Accept, "fragment")
return Accept return Accept
} }

@ -144,11 +144,12 @@ func TestNoAllocs(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
got := int(testing.AllocsPerRun(1000, func() { got := int(testing.AllocsPerRun(1000, func() {
var q ParsedPacket q := &ParsedPacket{}
q.Decode(test.packet)
if test.in { if test.in {
acl.RunIn(test.packet, &q, 0) acl.RunIn(q, 0)
} else { } else {
acl.RunOut(test.packet, &q, 0) acl.RunOut(q, 0)
} }
})) }))
@ -187,12 +188,13 @@ func BenchmarkFilter(b *testing.B) {
for _, bench := range benches { for _, bench := range benches {
b.Run(bench.name, func(b *testing.B) { b.Run(bench.name, func(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
var q ParsedPacket q := &ParsedPacket{}
q.Decode(bench.packet)
// This branch seems to have no measurable impact on performance. // This branch seems to have no measurable impact on performance.
if bench.in { if bench.in {
acl.RunIn(bench.packet, &q, 0) acl.RunIn(q, 0)
} else { } else {
acl.RunOut(bench.packet, &q, 0) acl.RunOut(q, 0)
} }
} }
}) })
@ -215,7 +217,9 @@ func TestPreFilter(t *testing.T) {
} }
f := NewAllowNone(t.Logf) f := NewAllowNone(t.Logf)
for _, testPacket := range packets { for _, testPacket := range packets {
got := f.pre([]byte(testPacket.b), &ParsedPacket{}, LogDrops|LogAccepts) p := &ParsedPacket{}
p.Decode(testPacket.b)
got := f.pre(p, LogDrops|LogAccepts)
if got != testPacket.want { if got != testPacket.want {
t.Errorf("%q got=%v want=%v packet:\n%s", testPacket.desc, got, testPacket.want, packet.Hexdump(testPacket.b)) t.Errorf("%q got=%v want=%v packet:\n%s", testPacket.desc, got, testPacket.want, packet.Hexdump(testPacket.b))
} }

@ -102,7 +102,7 @@ func ipChecksum(b []byte) uint16 {
// It extracts only the subprotocol id, IP addresses, and (if any) ports, // It extracts only the subprotocol id, IP addresses, and (if any) ports,
// and shouldn't need any memory allocation. // and shouldn't need any memory allocation.
func (q *ParsedPacket) Decode(b []byte) { func (q *ParsedPacket) Decode(b []byte) {
q.b = nil q.b = b
if len(b) < ipHeaderLength { if len(b) < ipHeaderLength {
q.IPProto = Unknown q.IPProto = Unknown
@ -170,7 +170,6 @@ func (q *ParsedPacket) Decode(b []byte) {
} }
q.SrcPort = 0 q.SrcPort = 0
q.DstPort = 0 q.DstPort = 0
q.b = b
q.dataofs = q.subofs + icmpHeaderLength q.dataofs = q.subofs + icmpHeaderLength
return return
case TCP: case TCP:
@ -181,7 +180,6 @@ func (q *ParsedPacket) Decode(b []byte) {
q.SrcPort = get16(sub[0:2]) q.SrcPort = get16(sub[0:2])
q.DstPort = get16(sub[2:4]) q.DstPort = get16(sub[2:4])
q.TCPFlags = sub[13] & 0x3F q.TCPFlags = sub[13] & 0x3F
q.b = b
headerLength := (sub[12] & 0xF0) >> 2 headerLength := (sub[12] & 0xF0) >> 2
q.dataofs = q.subofs + int(headerLength) q.dataofs = q.subofs + int(headerLength)
return return
@ -192,7 +190,6 @@ func (q *ParsedPacket) Decode(b []byte) {
} }
q.SrcPort = get16(sub[0:2]) q.SrcPort = get16(sub[0:2])
q.DstPort = get16(sub[2:4]) q.DstPort = get16(sub[2:4])
q.b = b
q.dataofs = q.subofs + udpHeaderLength q.dataofs = q.subofs + udpHeaderLength
return return
default: default:
@ -244,6 +241,11 @@ func (q *ParsedPacket) UDPHeader() UDPHeader {
} }
} }
// Buffer returns the entire packet buffer.
func (q *ParsedPacket) Buffer() []byte {
return q.b
}
// Sub returns the IP subprotocol section. // Sub returns the IP subprotocol section.
func (q *ParsedPacket) Sub(begin, n int) []byte { func (q *ParsedPacket) Sub(begin, n int) []byte {
return q.b[q.subofs+begin : q.subofs+begin+n] return q.b[q.subofs+begin : q.subofs+begin+n]

@ -90,6 +90,7 @@ var ipv6PacketBuffer = []byte{
} }
var ipv6PacketDecode = ParsedPacket{ var ipv6PacketDecode = ParsedPacket{
b: ipv6PacketBuffer,
IPProto: IPv6, IPProto: IPv6,
} }
@ -100,6 +101,7 @@ var unknownPacketBuffer = []byte{
} }
var unknownPacketDecode = ParsedPacket{ var unknownPacketDecode = ParsedPacket{
b: unknownPacketBuffer,
IPProto: Unknown, IPProto: Unknown,
} }

@ -0,0 +1,274 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package tsdns provides a Resolver struct capable of resolving
// domains on a Tailscale network.
package tsdns
import (
"encoding/binary"
"errors"
"strings"
"sync"
dns "golang.org/x/net/dns/dnsmessage"
"inet.af/netaddr"
"tailscale.com/types/logger"
"tailscale.com/wgengine/packet"
)
// defaultTTL is the TTL in seconds of all responses from Resolver.
const defaultTTL = 600
var (
errMapNotSet = errors.New("domain map not set")
errNoSuchDomain = errors.New("domain does not exist")
errNotImplemented = errors.New("query type not implemented")
errNotOurName = errors.New("not an *.ipn.dev domain")
errNotQuery = errors.New("not a DNS query")
)
var (
defaultIP = packet.IP(binary.BigEndian.Uint32([]byte{100, 100, 100, 100}))
defaultPort = uint16(53)
)
// Map is all the data Resolver needs to resolve DNS queries.
type Map struct {
// DomainToIP is a mapping of Tailscale domains to their IP addresses.
// For example, monitoring.ipn.dev -> 100.64.0.1.
DomainToIP map[string]netaddr.IP
}
// Resolver is a DNS resolver for domain names of the form *.ipn.dev
// It is intended
type Resolver struct {
logf logger.Logf
// ip is the IP on which the resolver is listening.
ip packet.IP
// port is the port on which the resolver is listening.
port uint16
// mu guards the following fields from being updated while used.
mu sync.Mutex
// dnsMap is the map most recently received from the control server.
dnsMap *Map
}
// NewResolver constructs a resolver with default parameters.
func NewResolver(logf logger.Logf) *Resolver {
r := &Resolver{
logf: logf,
ip: defaultIP,
port: defaultPort,
}
return r
}
// AcceptsPacket determines if the given packet is
// directed to this resolver (by ip and port).
// We also require that UDP be used to simplify things for now.
func (r *Resolver) AcceptsPacket(in *packet.ParsedPacket) bool {
return in.DstIP == r.ip && in.DstPort == r.port && in.IPProto == packet.UDP
}
// SetMap sets the resolver's DNS map.
func (r *Resolver) SetMap(m *Map) {
r.mu.Lock()
r.dnsMap = m
r.mu.Unlock()
}
// Resolve maps a given domain name to the IP address of the host that owns it.
func (r *Resolver) Resolve(domain string) (netaddr.IP, dns.RCode, error) {
// If not a subdomain of ipn.dev, then we must refuse this query.
// We do this before checking the map to distinguish beween nonexistent domains
// and misdirected queries.
if !strings.HasSuffix(domain, ".ipn.dev") {
return netaddr.IP{}, dns.RCodeRefused, errNotOurName
}
r.mu.Lock()
if r.dnsMap == nil {
r.mu.Unlock()
return netaddr.IP{}, dns.RCodeServerFailure, errMapNotSet
}
addr, found := r.dnsMap.DomainToIP[domain]
r.mu.Unlock()
if !found {
return netaddr.IP{}, dns.RCodeNameError, errNoSuchDomain
}
return addr, dns.RCodeSuccess, nil
}
type response struct {
Header dns.Header
ResourceHeader dns.ResourceHeader
Question dns.Question
// TODO(dmytro): support IPv6.
IP netaddr.IP
}
// parseQuery parses the query in given packet into a response struct.
func (r *Resolver) parseQuery(query *packet.ParsedPacket, resp *response) error {
var parser dns.Parser
var err error
resp.Header, err = parser.Start(query.Payload())
if err != nil {
return err
}
if resp.Header.Response {
return errNotQuery
}
resp.Question, err = parser.Question()
if err != nil {
return err
}
return nil
}
// makeResponse resolves the question stored in resp and sets the answer fields.
func (r *Resolver) makeResponse(resp *response) error {
var err error
name := resp.Question.Name.String()
if len(name) > 0 {
name = name[:len(name)-1]
}
if resp.Question.Type == dns.TypeA {
// Remove final dot from name: *.ipn.dev. -> *.ipn.dev
resp.IP, resp.Header.RCode, err = r.Resolve(name)
} else {
resp.Header.RCode = dns.RCodeNotImplemented
err = errNotImplemented
}
return err
}
// marshalAnswer serializes the answer record into an active builder.
func marshalAnswer(resp *response, builder *dns.Builder) error {
var answer dns.AResource
err := builder.StartAnswers()
if err != nil {
return err
}
answerHeader := dns.ResourceHeader{
Name: resp.Question.Name,
Type: dns.TypeA,
Class: dns.ClassINET,
TTL: defaultTTL,
}
ip := resp.IP.As16()
copy(answer.A[:], ip[12:])
return builder.AResource(answerHeader, answer)
}
// marshalResponse serializes the DNS response into an active builder.
func marshalResponse(resp *response, builder *dns.Builder) ([]byte, error) {
resp.Header.Response = true
resp.Header.Authoritative = true
if resp.Header.RecursionDesired {
resp.Header.RecursionAvailable = true
}
err := builder.StartQuestions()
if err != nil {
return nil, err
}
err = builder.Question(resp.Question)
if err != nil {
return nil, err
}
if resp.Header.RCode == dns.RCodeSuccess {
err = marshalAnswer(resp, builder)
if err != nil {
return nil, err
}
}
return builder.Finish()
}
func marshalResponsePacket(query *packet.ParsedPacket, resp *response, buf []byte) ([]byte, error) {
udpHeader := query.UDPHeader()
udpHeader.ToResponse()
offset := udpHeader.Len()
// dns.Builder appends to the passed buffer (without reallocation when possible),
// so we pass in a zero-length slice starting at the point it should start writing.
builder := dns.NewBuilder(buf[offset:offset], resp.Header)
// rbuf is the response slice with the correct length starting at offset.
rbuf, err := marshalResponse(resp, &builder)
if err != nil {
return nil, err
}
end := offset + len(rbuf)
err = udpHeader.Marshal(buf[:end])
if err != nil {
return nil, err
}
return buf[:end], nil
}
// Respond writes a response to query into buf and returns buf trimmed to the response length.
// It is assumed that r.AcceptsPacket(query) is true.
func (r *Resolver) Respond(query *packet.ParsedPacket, buf []byte) ([]byte, error) {
var resp response
var err error
// 0. Verify that contract is upheld.
if !r.AcceptsPacket(query) {
r.logf("[unexpected] tsdns: Respond called on query not for this resolver")
resp.Header.RCode = dns.RCodeServerFailure
return marshalResponsePacket(query, &resp, buf)
}
// A DNS response is at least as long as the query
if len(buf) < len(query.Buffer()) {
r.logf("[unexpected] tsdns: response buffer is too small")
resp.Header.RCode = dns.RCodeServerFailure
return marshalResponsePacket(query, &resp, buf)
}
// 1. Parse query packet.
err = r.parseQuery(query, &resp)
// We will not return this error: it is the sender's fault.
if err != nil {
r.logf("tsdns: error during query parsing: %v", err)
resp.Header.RCode = dns.RCodeFormatError
return marshalResponsePacket(query, &resp, buf)
}
// 2. Service the query.
err = r.makeResponse(&resp)
// We will not return this error: it is the sender's fault.
if err != nil {
r.logf("tsdns: error during name resolution: %v", err)
return marshalResponsePacket(query, &resp, buf)
}
// For now, we require IPv4 in all cases.
// If we somehow came up with a non-IPv4 address, it's our fault.
if !resp.IP.Is4() {
resp.Header.RCode = dns.RCodeServerFailure
r.logf("tsdns: error during name resolution: IPv6 address: %v", resp.IP)
}
// 3. Serialize the response.
return marshalResponsePacket(query, &resp, buf)
}

@ -10,6 +10,7 @@ import (
"errors" "errors"
"io" "io"
"os" "os"
"sync"
"sync/atomic" "sync/atomic"
"github.com/tailscale/wireguard-go/device" "github.com/tailscale/wireguard-go/device"
@ -19,10 +20,12 @@ import (
"tailscale.com/wgengine/packet" "tailscale.com/wgengine/packet"
) )
const ( const maxBufferSize = device.MaxMessageSize
readMaxSize = device.MaxMessageSize
readOffset = device.MessageTransportHeaderSize // PacketStartOffset is the minimal amount of leading space that must exist
) // before &packet[offset] in a packet passed to Read, Write, or InjectInboundDirect.
// This is necessary to avoid reallocation in wireguard-go internals.
const PacketStartOffset = device.MessageTransportHeaderSize
// MaxPacketSize is the maximum size (in bytes) // MaxPacketSize is the maximum size (in bytes)
// of a packet that can be injected into a tstun.TUN. // of a packet that can be injected into a tstun.TUN.
@ -35,7 +38,15 @@ var (
ErrFiltered = errors.New("packet dropped by filter") ErrFiltered = errors.New("packet dropped by filter")
) )
var errPacketTooBig = errors.New("packet too big") var (
errPacketTooBig = errors.New("packet too big")
errOffsetTooBig = errors.New("offset larger than buffer length")
errOffsetTooSmall = errors.New("offset smaller than PacketStartOffset")
)
// FilterFunc is a packet-filtering function with access to the TUN device.
// It must not hold onto the packet struct, as its backing storage will be reused.
type FilterFunc func(*packet.ParsedPacket, *TUN) filter.Response
// TUN wraps a tun.Device from wireguard-go, // TUN wraps a tun.Device from wireguard-go,
// augmenting it with filtering and packet injection. // augmenting it with filtering and packet injection.
@ -47,10 +58,14 @@ type TUN struct {
tdev tun.Device tdev tun.Device
// buffer stores the oldest unconsumed packet from tdev. // buffer stores the oldest unconsumed packet from tdev.
// It is made a static buffer in order to avoid graticious allocation. // It is made a static buffer in order to avoid allocations.
buffer [readMaxSize]byte buffer [maxBufferSize]byte
// bufferConsumed synchronizes access to buffer (shared by Read and poll). // bufferConsumed synchronizes access to buffer (shared by Read and poll).
bufferConsumed chan struct{} bufferConsumed chan struct{}
// parsedPacketPool holds a pool of ParsedPacket structs for use in filtering.
// This is needed because escape analysis cannot see that parsed packets
// do not escape through {Pre,Post}Filter{In,Out}.
parsedPacketPool sync.Pool // of *packet.ParsedPacket
// closed signals poll (by closing) when the device is closed. // closed signals poll (by closing) when the device is closed.
closed chan struct{} closed chan struct{}
@ -73,8 +88,19 @@ type TUN struct {
// filterFlags control the verbosity of logging packet drops/accepts. // filterFlags control the verbosity of logging packet drops/accepts.
filterFlags filter.RunFlags filterFlags filter.RunFlags
// insecure disables all filtering when set. This is useful in tests. // PreFilterIn is the inbound filter function that runs before the main filter
insecure bool // and therefore sees the packets that may be later dropped by it.
PreFilterIn FilterFunc
// PostFilterIn is the inbound filter function that runs after the main filter.
PostFilterIn FilterFunc
// PreFilterOut is the outbound filter function that runs before the main filter
// and therefore sees the packets that may be later dropped by it.
PreFilterOut FilterFunc
// PostFilterOut is the outbound filter function that runs after the main filter.
PostFilterOut FilterFunc
// disableFilter disables all filtering when set. This should only be used in tests.
disableFilter bool
} }
func WrapTUN(logf logger.Logf, tdev tun.Device) *TUN { func WrapTUN(logf logger.Logf, tdev tun.Device) *TUN {
@ -87,8 +113,14 @@ func WrapTUN(logf logger.Logf, tdev tun.Device) *TUN {
closed: make(chan struct{}), closed: make(chan struct{}),
errors: make(chan error), errors: make(chan error),
outbound: make(chan []byte), outbound: make(chan []byte),
// TODO(dmytro): (highly rate-limited) hexdumps should happen on unknown packets.
filterFlags: filter.LogAccepts | filter.LogDrops, filterFlags: filter.LogAccepts | filter.LogDrops,
} }
tun.parsedPacketPool.New = func() interface{} {
return new(packet.ParsedPacket)
}
go tun.poll() go tun.poll()
// The buffer starts out consumed. // The buffer starts out consumed.
tun.bufferConsumed <- struct{}{} tun.bufferConsumed <- struct{}{}
@ -140,10 +172,10 @@ func (t *TUN) poll() {
// continue // continue
} }
// Read may use memory in t.buffer before readOffset for mandatory headers. // Read may use memory in t.buffer before PacketStartOffset for mandatory headers.
// This is the rationale behind the tun.TUN.{Read,Write} interfaces // This is the rationale behind the tun.TUN.{Read,Write} interfaces
// and the reason t.buffer has size MaxMessageSize and not MaxContentSize. // and the reason t.buffer has size MaxMessageSize and not MaxContentSize.
n, err := t.tdev.Read(t.buffer[:], readOffset) n, err := t.tdev.Read(t.buffer[:], PacketStartOffset)
if err != nil { if err != nil {
select { select {
case <-t.closed: case <-t.closed:
@ -165,26 +197,41 @@ func (t *TUN) poll() {
select { select {
case <-t.closed: case <-t.closed:
return return
case t.outbound <- t.buffer[readOffset : readOffset+n]: case t.outbound <- t.buffer[PacketStartOffset : PacketStartOffset+n]:
// continue // continue
} }
} }
} }
func (t *TUN) filterOut(buf []byte) filter.Response { func (t *TUN) filterOut(buf []byte) filter.Response {
p := t.parsedPacketPool.Get().(*packet.ParsedPacket)
defer t.parsedPacketPool.Put(p)
p.Decode(buf)
if t.PreFilterOut != nil {
if t.PreFilterOut(p, t) == filter.Drop {
return filter.Drop
}
}
filt, _ := t.filter.Load().(*filter.Filter) filt, _ := t.filter.Load().(*filter.Filter)
if filt == nil { if filt == nil {
t.logf("Warning: you forgot to use SetFilter()! Packet dropped.") t.logf("tstun: warning: you forgot to use SetFilter()! Packet dropped.")
return filter.Drop return filter.Drop
} }
var p packet.ParsedPacket if filt.RunOut(p, t.filterFlags) != filter.Accept {
if filt.RunOut(buf, &p, t.filterFlags) == filter.Accept { return filter.Drop
return filter.Accept
} }
if t.PostFilterOut != nil {
if t.PostFilterOut(p, t) == filter.Drop {
return filter.Drop return filter.Drop
}
}
return filter.Accept
} }
func (t *TUN) Read(buf []byte, offset int) (int, error) { func (t *TUN) Read(buf []byte, offset int) (int, error) {
@ -200,12 +247,16 @@ func (t *TUN) Read(buf []byte, offset int) (int, error) {
// t.buffer has a fixed location in memory, // t.buffer has a fixed location in memory,
// so this is the easiest way to tell when it has been consumed. // so this is the easiest way to tell when it has been consumed.
// &packet[0] can be used because empty packets do not reach t.outbound. // &packet[0] can be used because empty packets do not reach t.outbound.
if &packet[0] == &t.buffer[readOffset] { if &packet[0] == &t.buffer[PacketStartOffset] {
t.bufferConsumed <- struct{}{} t.bufferConsumed <- struct{}{}
} else {
// If the packet is not from t.buffer, then it is an injected packet.
// In this case, we return eary to bypass filtering
return n, nil
} }
} }
if !t.insecure { if !t.disableFilter {
response := t.filterOut(buf[offset : offset+n]) response := t.filterOut(buf[offset : offset+n])
if response != filter.Accept { if response != filter.Accept {
// Wireguard considers read errors fatal; pretend nothing was read // Wireguard considers read errors fatal; pretend nothing was read
@ -217,35 +268,38 @@ func (t *TUN) Read(buf []byte, offset int) (int, error) {
} }
func (t *TUN) filterIn(buf []byte) filter.Response { func (t *TUN) filterIn(buf []byte) filter.Response {
p := t.parsedPacketPool.Get().(*packet.ParsedPacket)
defer t.parsedPacketPool.Put(p)
p.Decode(buf)
if t.PreFilterIn != nil {
if t.PreFilterIn(p, t) == filter.Drop {
return filter.Drop
}
}
filt, _ := t.filter.Load().(*filter.Filter) filt, _ := t.filter.Load().(*filter.Filter)
if filt == nil { if filt == nil {
t.logf("Warning: you forgot to use SetFilter()! Packet dropped.") t.logf("tstun: warning: you forgot to use SetFilter()! Packet dropped.")
return filter.Drop return filter.Drop
} }
var p packet.ParsedPacket if filt.RunIn(p, t.filterFlags) != filter.Accept {
if filt.RunIn(buf, &p, t.filterFlags) == filter.Accept {
// Only in fake mode, answer any incoming pings.
if p.IsEchoRequest() {
ft, ok := t.tdev.(*fakeTUN)
if ok {
header := p.ICMPHeader()
header.ToResponse()
packet := packet.Generate(&header, p.Payload())
ft.Write(packet, 0)
// We already handled it, stop.
return filter.Drop return filter.Drop
} }
if t.PostFilterIn != nil {
if t.PostFilterIn(p, t) == filter.Drop {
return filter.Drop
} }
return filter.Accept
} }
return filter.Drop return filter.Accept
} }
func (t *TUN) Write(buf []byte, offset int) (int, error) { func (t *TUN) Write(buf []byte, offset int) (int, error) {
if !t.insecure { if !t.disableFilter {
response := t.filterIn(buf[offset:]) response := t.filterIn(buf[offset:])
if response != filter.Accept { if response != filter.Accept {
return 0, ErrFiltered return 0, ErrFiltered
@ -264,24 +318,53 @@ func (t *TUN) SetFilter(filt *filter.Filter) {
t.filter.Store(filt) t.filter.Store(filt)
} }
// InjectInbound makes the TUN device behave as if a packet // InjectInboundDirect makes the TUN device behave as if a packet
// with the given contents was received from the network. // with the given contents was received from the network.
// It blocks and does not take ownership of the packet. // It blocks and does not take ownership of the packet.
// Injecting an empty packet is a no-op. // The injected packet will not pass through inbound filters.
func (t *TUN) InjectInbound(packet []byte) error { //
// The packet contents are to start at &buf[offset].
// offset must be greater or equal to PacketStartOffset.
// The space before &buf[offset] will be used by Wireguard.
func (t *TUN) InjectInboundDirect(buf []byte, offset int) error {
if len(buf) > MaxPacketSize {
return errPacketTooBig
}
if len(buf) < offset {
return errOffsetTooBig
}
if offset < PacketStartOffset {
return errOffsetTooSmall
}
// Write to the underlying device to skip filters.
_, err := t.tdev.Write(buf, offset)
return err
}
// InjectInboundCopy takes a packet without leading space,
// reallocates it to conform to the InjectInbondDirect interface
// and calls InjectInboundDirect on it. Injecting a nil packet is a no-op.
func (t *TUN) InjectInboundCopy(packet []byte) error {
// We duplicate this check from InjectInboundDirect here
// to avoid wasting an allocation on an oversized packet.
if len(packet) > MaxPacketSize { if len(packet) > MaxPacketSize {
return errPacketTooBig return errPacketTooBig
} }
if len(packet) == 0 { if len(packet) == 0 {
return nil return nil
} }
_, err := t.Write(packet, 0)
return err buf := make([]byte, PacketStartOffset+len(packet))
copy(buf[PacketStartOffset:], packet)
return t.InjectInboundDirect(buf, PacketStartOffset)
} }
// InjectOutbound makes the TUN device behave as if a packet // InjectOutbound makes the TUN device behave as if a packet
// with the given contents was sent to the network. // with the given contents was sent to the network.
// It does not block, but takes ownership of the packet. // It does not block, but takes ownership of the packet.
// The injected packet will not pass through outbound filters.
// Injecting an empty packet is a no-op. // Injecting an empty packet is a no-op.
func (t *TUN) InjectOutbound(packet []byte) error { func (t *TUN) InjectOutbound(packet []byte) error {
if len(packet) > MaxPacketSize { if len(packet) > MaxPacketSize {

@ -58,7 +58,7 @@ func newChannelTUN(logf logger.Logf, secure bool) (*tuntest.ChannelTUN, *TUN) {
if secure { if secure {
setfilter(logf, tun) setfilter(logf, tun)
} else { } else {
tun.insecure = true tun.disableFilter = true
} }
return chtun, tun return chtun, tun
} }
@ -69,7 +69,7 @@ func newFakeTUN(logf logger.Logf, secure bool) (*fakeTUN, *TUN) {
if secure { if secure {
setfilter(logf, tun) setfilter(logf, tun)
} else { } else {
tun.insecure = true tun.disableFilter = true
} }
return ftun.(*fakeTUN), tun return ftun.(*fakeTUN), tun
} }
@ -151,7 +151,7 @@ func TestWriteAndInject(t *testing.T) {
for _, packet := range injected { for _, packet := range injected {
go func(packet string) { go func(packet string) {
payload := []byte(packet) payload := []byte(packet)
err := tun.InjectInbound(payload) err := tun.InjectInboundCopy(payload)
if err != nil { if err != nil {
t.Errorf("%s: error: %v", packet, err) t.Errorf("%s: error: %v", packet, err)
} }

@ -34,6 +34,7 @@ import (
"tailscale.com/wgengine/monitor" "tailscale.com/wgengine/monitor"
"tailscale.com/wgengine/packet" "tailscale.com/wgengine/packet"
"tailscale.com/wgengine/router" "tailscale.com/wgengine/router"
"tailscale.com/wgengine/tsdns"
"tailscale.com/wgengine/tstun" "tailscale.com/wgengine/tstun"
) )
@ -54,6 +55,7 @@ type userspaceEngine struct {
tundev *tstun.TUN tundev *tstun.TUN
wgdev *device.Device wgdev *device.Device
router router.Router router router.Router
resolver *tsdns.Resolver
magicConn *magicsock.Conn magicConn *magicsock.Conn
linkMon *monitor.Mon linkMon *monitor.Mon
@ -73,6 +75,28 @@ type userspaceEngine struct {
// Lock ordering: wgLock, then mu. // Lock ordering: wgLock, then mu.
} }
// RouterGen is the signature for a function that creates a
// router.Router.
type RouterGen func(logf logger.Logf, wgdev *device.Device, tundev tun.Device) (router.Router, error)
type EngineConfig struct {
// Logf is the logging function used by the engine.
Logf logger.Logf
// TUN is the tun device used by the engine.
TUN tun.Device
// RouterGen is the function used to instantiate the router.
RouterGen RouterGen
// ListenPort is the port on which the engine will listen.
ListenPort uint16
// EchoRespondToAll determines whether ICMP Echo requests incoming from Tailscale peers
// will be intercepted and responded to, regardless of the source host.
EchoRespondToAll bool
// UseTailscaleDNS determines whether DNS requests for names of the form *.ipn.dev
// directed to the designated Taislcale DNS address (see wgengine/tsdns)
// will be intercepted and resolved by a tsdns.Resolver.
UseTailscaleDNS bool
}
type Loggify struct { type Loggify struct {
f logger.Logf f logger.Logf
} }
@ -84,8 +108,14 @@ func (l *Loggify) Write(b []byte) (int, error) {
func NewFakeUserspaceEngine(logf logger.Logf, listenPort uint16) (Engine, error) { func NewFakeUserspaceEngine(logf logger.Logf, listenPort uint16) (Engine, error) {
logf("Starting userspace wireguard engine (FAKE tuntap device).") logf("Starting userspace wireguard engine (FAKE tuntap device).")
tundev := tstun.WrapTUN(logf, tstun.NewFakeTUN()) conf := EngineConfig{
return NewUserspaceEngineAdvanced(logf, tundev, router.NewFake, listenPort) Logf: logf,
TUN: tstun.NewFakeTUN(),
RouterGen: router.NewFake,
ListenPort: listenPort,
EchoRespondToAll: true,
}
return NewUserspaceEngineAdvanced(conf)
} }
// NewUserspaceEngine creates the named tun device and returns a // NewUserspaceEngine creates the named tun device and returns a
@ -104,38 +134,53 @@ func NewUserspaceEngine(logf logger.Logf, tunname string, listenPort uint16) (En
return nil, err return nil, err
} }
logf("CreateTUN ok.") logf("CreateTUN ok.")
tundev := tstun.WrapTUN(logf, tun)
e, err := NewUserspaceEngineAdvanced(logf, tundev, router.New, listenPort) conf := EngineConfig{
Logf: logf,
TUN: tun,
RouterGen: router.New,
ListenPort: listenPort,
// TODO(dmytro): plumb this down.
UseTailscaleDNS: true,
}
e, err := NewUserspaceEngineAdvanced(conf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return e, err return e, err
} }
// RouterGen is the signature for a function that creates a // NewUserspaceEngineAdvanced is like NewUserspaceEngine
// router.Router. // but provides control over all config fields.
type RouterGen func(logf logger.Logf, wgdev *device.Device, tundev tun.Device) (router.Router, error) func NewUserspaceEngineAdvanced(conf EngineConfig) (Engine, error) {
return newUserspaceEngineAdvanced(conf)
// NewUserspaceEngineAdvanced is like NewUserspaceEngine but takes a pre-created TUN device and allows specifing
// a custom router constructor and listening port.
func NewUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen RouterGen, listenPort uint16) (Engine, error) {
return newUserspaceEngineAdvanced(logf, tundev, routerGen, listenPort)
} }
func newUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen RouterGen, listenPort uint16) (_ Engine, reterr error) { func newUserspaceEngineAdvanced(conf EngineConfig) (_ Engine, reterr error) {
logf := conf.Logf
e := &userspaceEngine{ e := &userspaceEngine{
logf: logf, logf: logf,
reqCh: make(chan struct{}, 1), reqCh: make(chan struct{}, 1),
waitCh: make(chan struct{}), waitCh: make(chan struct{}),
tundev: tundev, tundev: tstun.WrapTUN(logf, conf.TUN),
resolver: tsdns.NewResolver(logf),
pingers: make(map[wgcfg.Key]*pinger), pingers: make(map[wgcfg.Key]*pinger),
} }
e.linkState, _ = getLinkState() e.linkState, _ = getLinkState()
// Respond to all pings only in fake mode.
if conf.EchoRespondToAll {
e.tundev.PostFilterIn = echoRespondToAll
}
if conf.UseTailscaleDNS {
e.tundev.PreFilterOut = e.handleDNS
}
mon, err := monitor.New(logf, func() { e.LinkChange(false) }) mon, err := monitor.New(logf, func() { e.LinkChange(false) })
if err != nil { if err != nil {
tundev.Close() e.tundev.Close()
return nil, err return nil, err
} }
e.linkMon = mon e.linkMon = mon
@ -149,12 +194,12 @@ func newUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen R
} }
magicsockOpts := magicsock.Options{ magicsockOpts := magicsock.Options{
Logf: logf, Logf: logf,
Port: listenPort, Port: conf.ListenPort,
EndpointsFunc: endpointsFn, EndpointsFunc: endpointsFn,
} }
e.magicConn, err = magicsock.NewConn(magicsockOpts) e.magicConn, err = magicsock.NewConn(magicsockOpts)
if err != nil { if err != nil {
tundev.Close() e.tundev.Close()
return nil, fmt.Errorf("wgengine: %v", err) return nil, fmt.Errorf("wgengine: %v", err)
} }
@ -211,7 +256,7 @@ func newUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen R
// Pass the underlying tun.(*NativeDevice) to the router: // Pass the underlying tun.(*NativeDevice) to the router:
// routers do not Read or Write, but do access native interfaces. // routers do not Read or Write, but do access native interfaces.
e.router, err = routerGen(logf, e.wgdev, e.tundev.Unwrap()) e.router, err = conf.RouterGen(logf, e.wgdev, e.tundev.Unwrap())
if err != nil { if err != nil {
e.magicConn.Close() e.magicConn.Close()
return nil, err return nil, err
@ -256,6 +301,37 @@ func newUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen R
return e, nil return e, nil
} }
// echoRespondToAll is an inbound post-filter responding to all echo requests.
func echoRespondToAll(p *packet.ParsedPacket, t *tstun.TUN) filter.Response {
if p.IsEchoRequest() {
header := p.ICMPHeader()
header.ToResponse()
packet := packet.Generate(&header, p.Payload())
t.InjectOutbound(packet)
// We already handled it, stop.
return filter.Drop
}
return filter.Accept
}
// handleDNS is an outbound pre-filter resolving Tailscale domains.
func (e *userspaceEngine) handleDNS(p *packet.ParsedPacket, t *tstun.TUN) filter.Response {
if e.resolver.AcceptsPacket(p) {
// TODO(dmytro): avoid this allocation without having tsdns know tstun quirks.
buf := make([]byte, tstun.MaxPacketSize)
offset := tstun.PacketStartOffset
response, err := e.resolver.Respond(p, buf[offset:])
if err != nil {
e.logf("DNS resolver error: %v", err)
} else {
t.InjectInboundDirect(buf[:offset+len(response)], offset)
}
// We already handled it, stop.
return filter.Drop
}
return filter.Accept
}
// pinger sends ping packets for a few seconds. // pinger sends ping packets for a few seconds.
// //
// These generated packets are used to ensure we trigger the spray logic in // These generated packets are used to ensure we trigger the spray logic in
@ -447,6 +523,10 @@ func (e *userspaceEngine) SetFilter(filt *filter.Filter) {
e.tundev.SetFilter(filt) e.tundev.SetFilter(filt)
} }
func (e *userspaceEngine) SetDNSMap(dm *tsdns.Map) {
e.resolver.SetMap(dm)
}
func (e *userspaceEngine) SetStatusCallback(cb StatusCallback) { func (e *userspaceEngine) SetStatusCallback(cb StatusCallback) {
e.mu.Lock() e.mu.Lock()
defer e.mu.Unlock() defer e.mu.Unlock()

@ -15,6 +15,7 @@ import (
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/wgengine/filter" "tailscale.com/wgengine/filter"
"tailscale.com/wgengine/router" "tailscale.com/wgengine/router"
"tailscale.com/wgengine/tsdns"
) )
// NewWatchdog wraps an Engine and makes sure that all methods complete // NewWatchdog wraps an Engine and makes sure that all methods complete
@ -74,6 +75,9 @@ func (e *watchdogEngine) GetFilter() *filter.Filter {
func (e *watchdogEngine) SetFilter(filt *filter.Filter) { func (e *watchdogEngine) SetFilter(filt *filter.Filter) {
e.watchdog("SetFilter", func() { e.wrap.SetFilter(filt) }) e.watchdog("SetFilter", func() { e.wrap.SetFilter(filt) })
} }
func (e *watchdogEngine) SetDNSMap(dm *tsdns.Map) {
e.watchdog("SetDNSMap", func() { e.wrap.SetDNSMap(dm) })
}
func (e *watchdogEngine) SetStatusCallback(cb StatusCallback) { func (e *watchdogEngine) SetStatusCallback(cb StatusCallback) {
e.watchdog("SetStatusCallback", func() { e.wrap.SetStatusCallback(cb) }) e.watchdog("SetStatusCallback", func() { e.wrap.SetStatusCallback(cb) })
} }

@ -10,9 +10,6 @@ import (
"strings" "strings"
"testing" "testing"
"time" "time"
"tailscale.com/wgengine/router"
"tailscale.com/wgengine/tstun"
) )
func TestWatchdog(t *testing.T) { func TestWatchdog(t *testing.T) {
@ -20,8 +17,7 @@ func TestWatchdog(t *testing.T) {
t.Run("default watchdog does not fire", func(t *testing.T) { t.Run("default watchdog does not fire", func(t *testing.T) {
t.Parallel() t.Parallel()
tun := tstun.WrapTUN(t.Logf, tstun.NewFakeTUN()) e, err := NewFakeUserspaceEngine(t.Logf, 0)
e, err := NewUserspaceEngineAdvanced(t.Logf, tun, router.NewFake, 0)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -37,8 +33,7 @@ func TestWatchdog(t *testing.T) {
t.Run("watchdog fires on blocked getStatus", func(t *testing.T) { t.Run("watchdog fires on blocked getStatus", func(t *testing.T) {
t.Parallel() t.Parallel()
tun := tstun.WrapTUN(t.Logf, tstun.NewFakeTUN()) e, err := NewFakeUserspaceEngine(t.Logf, 0)
e, err := NewUserspaceEngineAdvanced(t.Logf, tun, router.NewFake, 0)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

@ -13,6 +13,7 @@ import (
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/wgengine/filter" "tailscale.com/wgengine/filter"
"tailscale.com/wgengine/router" "tailscale.com/wgengine/router"
"tailscale.com/wgengine/tsdns"
) )
// ByteCount is the number of bytes that have been sent or received. // ByteCount is the number of bytes that have been sent or received.
@ -65,6 +66,9 @@ type Engine interface {
// SetFilter updates the packet filter. // SetFilter updates the packet filter.
SetFilter(*filter.Filter) SetFilter(*filter.Filter)
// SetDNSMap updates the DNS map.
SetDNSMap(*tsdns.Map)
// SetStatusCallback sets the function to call when the // SetStatusCallback sets the function to call when the
// WireGuard status changes. // WireGuard status changes.
SetStatusCallback(StatusCallback) SetStatusCallback(StatusCallback)

Loading…
Cancel
Save