diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index a11e47885..738c17072 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -346,6 +346,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de W 💣 tailscale.com/util/osdiag/internal/wsc from tailscale.com/util/osdiag tailscale.com/util/osshare from tailscale.com/ipn/ipnlocal+ W tailscale.com/util/pidowner from tailscale.com/ipn/ipnauth + tailscale.com/util/race from tailscale.com/net/dns/resolver tailscale.com/util/racebuild from tailscale.com/logpolicy tailscale.com/util/rands from tailscale.com/ipn/ipnlocal+ tailscale.com/util/ringbuffer from tailscale.com/wgengine/magicsock diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index edcf9bbe6..4a0bbc7fc 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -18,6 +18,7 @@ import ( "sort" "strings" "sync" + "sync/atomic" "time" dns "golang.org/x/net/dns/dnsmessage" @@ -35,6 +36,7 @@ import ( "tailscale.com/types/nettype" "tailscale.com/util/cloudenv" "tailscale.com/util/dnsname" + "tailscale.com/util/race" "tailscale.com/version" ) @@ -70,6 +72,10 @@ const ( // (e.g. how long to wait to query Google's 8.8.4.4 after 8.8.8.8). wellKnownHostBackupDelay = 200 * time.Millisecond + // udpRaceTimeout is the timeout after which we will start a DNS query + // over TCP while waiting for the UDP query to complete. + udpRaceTimeout = 2 * time.Second + // tcpQueryTimeout is the timeout for a DNS query performed over TCP. // It matches the default 5sec timeout of the 'dig' utility. 
tcpQueryTimeout = 5 * time.Second @@ -488,47 +494,97 @@ func (f *forwarder) send(ctx context.Context, fq *forwardQuery, rr resolverAndDe return nil, fmt.Errorf("tls:// resolvers not supported yet") } - ret, err = f.sendUDP(ctx, fq, rr) - if err != nil { - return nil, err - } + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + isUDPQuery := fq.family == "udp" + skipTCP := skipTCPRetry() || (f.controlKnobs != nil && f.controlKnobs.DisableDNSForwarderTCPRetries.Load()) + + // Print logs about retries if this was because of a truncated response. + var explicitRetry atomic.Bool // true if truncated UDP response retried + defer func() { + if !explicitRetry.Load() { + return + } + if err == nil { + f.logf("forwarder.send(%q): successfully retried via TCP", rr.name.Addr) + } else { + f.logf("forwarder.send(%q): could not retry via TCP: %v", rr.name.Addr, err) + } + }() + + firstUDP := func(ctx context.Context) ([]byte, error) { + resp, err := f.sendUDP(ctx, fq, rr) + if err != nil { + return nil, err + } + if !truncatedFlagSet(resp) { + // Successful, non-truncated response; no retry. + return resp, nil + } - if !truncatedFlagSet(ret) { - // Successful, non-truncated response; return it. - return ret, nil - } - if fq.family == "udp" { // If this is a UDP query, return it regardless of whether the // response is truncated or not; the client can retry // communicating with tailscaled over TCP. There's no point // falling back to TCP for a truncated query if we can't return // the results to the client. - return ret, nil + if isUDPQuery { + return resp, nil + } + + if skipTCP { + // Envknob or control knob disabled the TCP retry behaviour; + // just return what we have. + return resp, nil + } + + // This is a TCP query from the client, and the UDP response + // from the upstream DNS server is truncated; map this to an + // error to cause our retry helper to immediately kick off the + // TCP retry. 
+ explicitRetry.Store(true) + return nil, truncatedResponseError{resp} + } + thenTCP := func(ctx context.Context) ([]byte, error) { + // If we're skipping the TCP fallback, then wait until the + // context is canceled and return that error (i.e. not + // returning anything). + if skipTCP { + <-ctx.Done() + return nil, ctx.Err() + } + + return f.sendTCP(ctx, fq, rr) } - if skipTCPRetry() || (f.controlKnobs != nil && f.controlKnobs.DisableDNSForwarderTCPRetries.Load()) { - // Envknob or control knob disabled the TCP retry behaviour; - // just return what we have. - return ret, nil + + // If the input query is TCP, then don't have a timeout between + // starting UDP and TCP. + timeout := udpRaceTimeout + if !isUDPQuery { + timeout = 0 } - // Don't retry if our context is done. - if err := ctx.Err(); err != nil { - return nil, err + // Kick off the race between the UDP and TCP queries. + rh := race.New[[]byte](timeout, firstUDP, thenTCP) + resp, err := rh.Start(ctx) + if err == nil { + return resp, nil } - // Retry over TCP, best-effort; return the truncated UDP response if we - // cannot query via TCP. - if ret2, err2 := f.sendTCP(ctx, fq, rr); err2 == nil { - if verboseDNSForward() { - f.logf("forwarder.send(%q): successfully retried via TCP", rr.name.Addr) - } - return ret2, nil - } else if verboseDNSForward() { - f.logf("forwarder.send(%q): could not retry via TCP: %v", rr.name.Addr, err2) + // If we got a truncated UDP response, return that instead of an error. 
+ var trErr truncatedResponseError + if errors.As(err, &trErr) { + return trErr.res, nil } - return ret, nil + return nil, err } +type truncatedResponseError struct { + res []byte +} + +func (tr truncatedResponseError) Error() string { return "response truncated" } + var errServerFailure = errors.New("response code indicates server issue") var errTxIDMismatch = errors.New("txid doesn't match") @@ -875,7 +931,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo } numErr++ if numErr == len(resolvers) { - if firstErr == errServerFailure { + if errors.Is(firstErr, errServerFailure) { res, err := servfailResponse(query) if err != nil { f.logf("building servfail response: %v", err) diff --git a/net/dns/resolver/forwarder_test.go b/net/dns/resolver/forwarder_test.go index b78e26c95..bfe2addc3 100644 --- a/net/dns/resolver/forwarder_test.go +++ b/net/dns/resolver/forwarder_test.go @@ -7,6 +7,7 @@ import ( "bytes" "context" "encoding/binary" + "errors" "flag" "fmt" "io" @@ -21,6 +22,7 @@ import ( "time" dns "golang.org/x/net/dns/dnsmessage" + "tailscale.com/control/controlknobs" "tailscale.com/envknob" "tailscale.com/net/netmon" "tailscale.com/net/tsdial" @@ -253,7 +255,16 @@ func FuzzClampEDNSSize(f *testing.F) { }) } -func runDNSServer(tb testing.TB, response []byte, onRequest func(bool, []byte)) (port uint16) { +type testDNSServerOptions struct { + SkipUDP bool + SkipTCP bool +} + +func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, onRequest func(bool, []byte)) (port uint16) { + if opts != nil && opts.SkipUDP && opts.SkipTCP { + tb.Fatal("cannot skip both UDP and TCP servers") + } + tcpResponse := make([]byte, len(response)+2) binary.BigEndian.PutUint16(tcpResponse, uint16(len(response))) copy(tcpResponse[2:], response) @@ -327,17 +338,20 @@ func runDNSServer(tb testing.TB, response []byte, onRequest func(bool, []byte)) } var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - for { - conn, err := 
tcpLn.Accept() - if err != nil { - return + + if opts == nil || !opts.SkipTCP { + wg.Add(1) + go func() { + defer wg.Done() + for { + conn, err := tcpLn.Accept() + if err != nil { + return + } + go handleConn(conn) } - go handleConn(conn) - } - }() + }() + } handleUDP := func(addr netip.AddrPort, req []byte) { onRequest(false, req) @@ -346,19 +360,21 @@ func runDNSServer(tb testing.TB, response []byte, onRequest func(bool, []byte)) } } - wg.Add(1) - go func() { - defer wg.Done() - for { - buf := make([]byte, 65535) - n, addr, err := udpLn.ReadFromUDPAddrPort(buf) - if err != nil { - return + if opts == nil || !opts.SkipUDP { + wg.Add(1) + go func() { + defer wg.Done() + for { + buf := make([]byte, 65535) + n, addr, err := udpLn.ReadFromUDPAddrPort(buf) + if err != nil { + return + } + buf = buf[:n] + go handleUDP(addr, buf) } - buf = buf[:n] - go handleUDP(addr, buf) - } - }() + }() + } tb.Cleanup(func() { tcpLn.Close() @@ -369,84 +385,72 @@ func runDNSServer(tb testing.TB, response []byte, onRequest func(bool, []byte)) return } -func TestForwarderTCPFallback(t *testing.T) { +func enableDebug(tb testing.TB) { const debugKnob = "TS_DEBUG_DNS_FORWARD_SEND" oldVal := os.Getenv(debugKnob) envknob.Setenv(debugKnob, "true") - t.Cleanup(func() { envknob.Setenv(debugKnob, oldVal) }) - - const domain = "large-dns-response.tailscale.com." + tb.Cleanup(func() { envknob.Setenv(debugKnob, oldVal) }) +} - // Make a response that's very large, containing a bunch of localhost addresses. 
- largeResponse := func() []byte { - name := dns.MustNewName(domain) +func makeLargeResponse(tb testing.TB, domain string) (request, response []byte) { + name := dns.MustNewName(domain) - builder := dns.NewBuilder(nil, dns.Header{}) - builder.StartQuestions() - builder.Question(dns.Question{ + builder := dns.NewBuilder(nil, dns.Header{}) + builder.StartQuestions() + builder.Question(dns.Question{ + Name: name, + Type: dns.TypeA, + Class: dns.ClassINET, + }) + builder.StartAnswers() + for i := 0; i < 120; i++ { + builder.AResource(dns.ResourceHeader{ Name: name, - Type: dns.TypeA, Class: dns.ClassINET, + TTL: 300, + }, dns.AResource{ + A: [4]byte{127, 0, 0, byte(i)}, }) - builder.StartAnswers() - for i := 0; i < 120; i++ { - builder.AResource(dns.ResourceHeader{ - Name: name, - Class: dns.ClassINET, - TTL: 300, - }, dns.AResource{ - A: [4]byte{127, 0, 0, byte(i)}, - }) - } + } - msg, err := builder.Finish() - if err != nil { - t.Fatal(err) - } - return msg - }() - if len(largeResponse) <= maxResponseBytes { - t.Fatalf("got len(largeResponse)=%d, want > %d", len(largeResponse), maxResponseBytes) + var err error + response, err = builder.Finish() + if err != nil { + tb.Fatal(err) + } + if len(response) <= maxResponseBytes { + tb.Fatalf("got len(largeResponse)=%d, want > %d", len(response), maxResponseBytes) } // Our request is a single A query for the domain in the answer, above. 
- request := func() []byte { - builder := dns.NewBuilder(nil, dns.Header{}) - builder.StartQuestions() - builder.Question(dns.Question{ - Name: dns.MustNewName(domain), - Type: dns.TypeA, - Class: dns.ClassINET, - }) - msg, err := builder.Finish() - if err != nil { - t.Fatal(err) - } - return msg - }() - - var sawUDPRequest, sawTCPRequest atomic.Bool - port := runDNSServer(t, largeResponse, func(isTCP bool, gotRequest []byte) { - if isTCP { - sawTCPRequest.Store(true) - } else { - sawUDPRequest.Store(true) - } - - if !bytes.Equal(request, gotRequest) { - t.Errorf("invalid request\ngot: %+v\nwant: %+v", gotRequest, request) - } + builder = dns.NewBuilder(nil, dns.Header{}) + builder.StartQuestions() + builder.Question(dns.Question{ + Name: dns.MustNewName(domain), + Type: dns.TypeA, + Class: dns.ClassINET, }) + request, err = builder.Finish() + if err != nil { + tb.Fatal(err) + } - netMon, err := netmon.New(t.Logf) + return +} + +func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwarder)) ([]byte, error) { + netMon, err := netmon.New(tb.Logf) if err != nil { - t.Fatal(err) + tb.Fatal(err) } var dialer tsdial.Dialer dialer.SetNetMon(netMon) - fwd := newForwarder(t.Logf, netMon, nil, &dialer, nil) + fwd := newForwarder(tb.Logf, netMon, nil, &dialer, nil) + if modify != nil { + modify(fwd) + } fq := &forwardQuery{ txid: getTxID(request), @@ -459,10 +463,41 @@ func TestForwarderTCPFallback(t *testing.T) { name: &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port)}, } - resp, err := fwd.send(context.Background(), fq, rr) + return fwd.send(context.Background(), fq, rr) +} + +func mustRunTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwarder)) []byte { + resp, err := runTestQuery(tb, port, request, modify) if err != nil { - t.Fatalf("error making request: %v", err) + tb.Fatalf("error making request: %v", err) } + return resp +} + +func TestForwarderTCPFallback(t *testing.T) { + enableDebug(t) + + const domain = 
"large-dns-response.tailscale.com." + + // Make a response that's very large, containing a bunch of localhost addresses. + request, largeResponse := makeLargeResponse(t, domain) + + var sawUDPRequest, sawTCPRequest atomic.Bool + port := runDNSServer(t, nil, largeResponse, func(isTCP bool, gotRequest []byte) { + if isTCP { + t.Logf("saw TCP request") + sawTCPRequest.Store(true) + } else { + t.Logf("saw UDP request") + sawUDPRequest.Store(true) + } + + if !bytes.Equal(request, gotRequest) { + t.Errorf("invalid request\ngot: %+v\nwant: %+v", gotRequest, request) + } + }) + + resp := mustRunTestQuery(t, port, request, nil) if !bytes.Equal(resp, largeResponse) { t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse) } @@ -473,3 +508,141 @@ func TestForwarderTCPFallback(t *testing.T) { t.Errorf("DNS server never saw UDP request") } } + +// Test to ensure that if the UDP listener is unresponsive, we always make a +// TCP request even if we never get a response. +func TestForwarderTCPFallbackTimeout(t *testing.T) { + enableDebug(t) + + const domain = "large-dns-response.tailscale.com." + + // Make a response that's very large, containing a bunch of localhost addresses. 
+ request, largeResponse := makeLargeResponse(t, domain) + + var sawTCPRequest atomic.Bool + opts := &testDNSServerOptions{SkipUDP: true} + port := runDNSServer(t, opts, largeResponse, func(isTCP bool, gotRequest []byte) { + if isTCP { + t.Logf("saw TCP request") + sawTCPRequest.Store(true) + } else { + t.Error("saw unexpected UDP request") + } + + if !bytes.Equal(request, gotRequest) { + t.Errorf("invalid request\ngot: %+v\nwant: %+v", gotRequest, request) + } + }) + + resp := mustRunTestQuery(t, port, request, nil) + if !bytes.Equal(resp, largeResponse) { + t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse) + } + if !sawTCPRequest.Load() { + t.Errorf("DNS server never saw TCP request") + } +} + +func TestForwarderTCPFallbackDisabled(t *testing.T) { + enableDebug(t) + + const domain = "large-dns-response.tailscale.com." + + // Make a response that's very large, containing a bunch of localhost addresses. + request, largeResponse := makeLargeResponse(t, domain) + + var sawUDPRequest atomic.Bool + port := runDNSServer(t, nil, largeResponse, func(isTCP bool, gotRequest []byte) { + if isTCP { + t.Error("saw unexpected TCP request") + } else { + t.Logf("saw UDP request") + sawUDPRequest.Store(true) + } + + if !bytes.Equal(request, gotRequest) { + t.Errorf("invalid request\ngot: %+v\nwant: %+v", gotRequest, request) + } + }) + + resp := mustRunTestQuery(t, port, request, func(fwd *forwarder) { + // Disable retries for this test. + fwd.controlKnobs = &controlknobs.Knobs{} + fwd.controlKnobs.DisableDNSForwarderTCPRetries.Store(true) + }) + + wantResp := append([]byte(nil), largeResponse[:maxResponseBytes]...) + + // Set the truncated flag on the expected response, since that's what we expect. 
+ flags := binary.BigEndian.Uint16(wantResp[2:4]) + flags |= dnsFlagTruncated + binary.BigEndian.PutUint16(wantResp[2:4], flags) + + if !bytes.Equal(resp, wantResp) { + t.Errorf("invalid response\ngot (%d): %+v\nwant (%d): %+v", len(resp), resp, len(wantResp), wantResp) + } + if !sawUDPRequest.Load() { + t.Errorf("DNS server never saw UDP request") + } +} + +// Test to ensure that we propagate DNS errors +func TestForwarderTCPFallbackError(t *testing.T) { + enableDebug(t) + + const domain = "error-response.tailscale.com." + + // Our response is a SERVFAIL + response := func() []byte { + name := dns.MustNewName(domain) + + builder := dns.NewBuilder(nil, dns.Header{ + RCode: dns.RCodeServerFailure, + }) + builder.StartQuestions() + builder.Question(dns.Question{ + Name: name, + Type: dns.TypeA, + Class: dns.ClassINET, + }) + response, err := builder.Finish() + if err != nil { + t.Fatal(err) + } + return response + }() + + // Our request is a single A query for the domain in the answer, above. 
+ request := func() []byte { + builder := dns.NewBuilder(nil, dns.Header{}) + builder.StartQuestions() + builder.Question(dns.Question{ + Name: dns.MustNewName(domain), + Type: dns.TypeA, + Class: dns.ClassINET, + }) + request, err := builder.Finish() + if err != nil { + t.Fatal(err) + } + return request + }() + + var sawRequest atomic.Bool + port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) { + sawRequest.Store(true) + if !bytes.Equal(request, gotRequest) { + t.Errorf("invalid request\ngot: %+v\nwant: %+v", gotRequest, request) + } + }) + + _, err := runTestQuery(t, port, request, nil) + if !sawRequest.Load() { + t.Error("did not see DNS request") + } + if err == nil { + t.Error("wanted error, got nil") + } else if !errors.Is(err, errServerFailure) { + t.Errorf("wanted errServerFailure, got: %v", err) + } +} diff --git a/net/dns/resolver/tsdns_test.go b/net/dns/resolver/tsdns_test.go index c135c92a1..882462012 100644 --- a/net/dns/resolver/tsdns_test.go +++ b/net/dns/resolver/tsdns_test.go @@ -1449,7 +1449,7 @@ func TestServfail(t *testing.T) { r.SetConfig(cfg) pkt, err := syncRespond(r, dnspacket("test.site.", dns.TypeA, noEdns)) - if err != errServerFailure { + if !errors.Is(err, errServerFailure) { t.Errorf("err = %v, want %v", err, errServerFailure) } diff --git a/util/race/race.go b/util/race/race.go new file mode 100644 index 000000000..041ce546f --- /dev/null +++ b/util/race/race.go @@ -0,0 +1,115 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package race contains a helper to "race" two functions, returning the first +// successful result. It also allows explicitly triggering the +// (possibly-waiting) second function when the first function returns an error +// or indicates that it should be retried. 
+package race
+
+import (
+	"context"
+	"errors"
+	"time"
+)
+
+type resultType int
+
+const (
+	first resultType = iota
+	second
+)
+
+// queryResult is an internal type for storing the result of a function call.
+type queryResult[T any] struct {
+	ty  resultType
+	res T
+	err error
+}
+
+// Func is the signature of a function to be called.
+type Func[T any] func(context.Context) (T, error)
+
+// Race allows running two functions concurrently and returning the first
+// non-error result returned.
+type Race[T any] struct {
+	func1, func2  Func[T]
+	d             time.Duration
+	results       chan queryResult[T]
+	startFallback chan struct{}
+}
+
+// New creates a new Race that, when Start is called, will immediately call
+// func1 to obtain a result. After the timeout d or if triggered by an error
+// response from func1, func2 will be called.
+func New[T any](d time.Duration, func1, func2 Func[T]) *Race[T] {
+	ret := &Race[T]{
+		func1:         func1,
+		func2:         func2,
+		d:             d,
+		results:       make(chan queryResult[T], 2),
+		startFallback: make(chan struct{}),
+	}
+	return ret
+}
+
+// Start runs the race, returning the first non-error result or the joined
+// errors of func1 and func2. It must be called at most once per Race.
+func (rh *Race[T]) Start(ctx context.Context) (T, error) {
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
+	// func1 is started immediately
+	go func() {
+		ret, err := rh.func1(ctx)
+		rh.results <- queryResult[T]{first, ret, err}
+	}()
+
+	// func2 is started after a timeout
+	go func() {
+		wait := time.NewTimer(rh.d)
+		defer wait.Stop()
+
+		// Wait for our timeout, trigger, or context to finish.
+		select {
+		case <-ctx.Done():
+			// Nothing to do; we're done
+			var zero T
+			rh.results <- queryResult[T]{second, zero, ctx.Err()}
+			return
+		case <-rh.startFallback:
+		case <-wait.C:
+		}
+
+		ret, err := rh.func2(ctx)
+		rh.results <- queryResult[T]{second, ret, err}
+	}()
+
+	// For each possible result, get it off the channel.
+	var errs []error
+	for i := 0; i < 2; i++ {
+		res := <-rh.results
+
+		// If this was an error, store it and hope that the other
+		// result gives us something.
+		if res.err != nil {
+			errs = append(errs, res.err)
+
+			// Start the fallback function immediately if this is
+			// the first function's error, to avoid having
+			// to wait.
+			if res.ty == first {
+				close(rh.startFallback)
+			}
+			continue
+		}
+
+		// Got a valid response! Return it.
+		return res.res, nil
+	}
+
+	// If we get here, both raced functions failed. Return whatever errors
+	// we have, joined together.
+	var zero T
+	return zero, errors.Join(errs...)
+}
diff --git a/util/race/race_test.go b/util/race/race_test.go
new file mode 100644
index 000000000..9c30e6adb
--- /dev/null
+++ b/util/race/race_test.go
@@ -0,0 +1,100 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package race
+
+import (
+	"context"
+	"errors"
+	"testing"
+	"time"
+)
+
+// TestRaceSuccess1 verifies that a successful func1 result is returned and
+// that func2 is not invoked.
+func TestRaceSuccess1(t *testing.T) {
+	const want = "success"
+	rh := New[string](
+		10*time.Second,
+		func(context.Context) (string, error) {
+			return want, nil
+		}, func(context.Context) (string, error) {
+			// func2 runs on a goroutine started by Start, so use
+			// t.Error here: t.Fatal may only be called from the
+			// goroutine running the test function.
+			t.Error("should not be called")
+			return "", nil
+		})
+	res, err := rh.Start(context.Background())
+	if err != nil {
+		t.Fatal(err)
+	}
+	if res != want {
+		t.Errorf("got res=%q, want %q", res, want)
+	}
+}
+
+// TestRaceRetry verifies that an error from func1 triggers func2 and that
+// func2's result is returned.
+func TestRaceRetry(t *testing.T) {
+	const want = "fallback"
+	rh := New[string](
+		10*time.Second,
+		func(context.Context) (string, error) {
+			return "", errors.New("some error")
+		}, func(context.Context) (string, error) {
+			return want, nil
+		})
+	res, err := rh.Start(context.Background())
+	if err != nil {
+		t.Fatal(err)
+	}
+	if res != want {
+		t.Errorf("got res=%q, want %q", res, want)
+	}
+}
+
+// TestRaceTimeout verifies that func2 is started once the timeout elapses
+// while func1 is still blocked.
+func TestRaceTimeout(t *testing.T) {
+	const want = "fallback"
+	rh := New[string](
+		100*time.Millisecond,
+		func(ctx context.Context) (string, error) {
+			// Block until Start cancels the context.
+			<-ctx.Done()
+			return "", ctx.Err()
+		}, func(context.Context) (string, error) {
+			return want, nil
+		})
+	res, err := rh.Start(context.Background())
+	if err != nil {
+		t.Fatal(err)
+	}
+	if res != want {
+		t.Errorf("got res=%q, want %q", res, want)
+	}
+}
+
+// TestRaceError verifies that when both functions fail, the returned error
+// wraps both underlying errors.
+func TestRaceError(t *testing.T) {
+	err1 := errors.New("error 1")
+	err2 := errors.New("error 2")
+
+	rh := New[string](
+		100*time.Millisecond,
+		func(ctx context.Context) (string, error) {
+			return "", err1
+		}, func(context.Context) (string, error) {
+			return "", err2
+		})
+
+	_, err := rh.Start(context.Background())
+	if !errors.Is(err, err1) {
+		t.Errorf("wanted err to contain err1; got %v", err)
+	}
+	if !errors.Is(err, err2) {
+		t.Errorf("wanted err to contain err2; got %v", err)
+	}
+}