net/dns/resolver: race UDP and TCP queries (#9544)

Instead of just falling back to making a TCP query to an upstream DNS
server when the UDP query returns a truncated query, also start a TCP
query in parallel with the UDP query after a given race timeout. This
ensures that if the upstream DNS server does not reply over UDP (or if
the response packet is blocked, or there's an error), we can still make
queries if the server replies to TCP queries.

This also adds a new package, util/race, to contain the logic required for
racing two different functions and returning the first non-error answer.

Updates tailscale/corp#14809

Signed-off-by: Andrew Dunham <andrew@du.nham.ca>
Change-Id: I4311702016c1093b1beaa31b135da1def6d86316
pull/9641/head
Andrew Dunham 1 year ago committed by GitHub
parent eb22c0dfc7
commit 286c6ce27c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -346,6 +346,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
W 💣 tailscale.com/util/osdiag/internal/wsc from tailscale.com/util/osdiag W 💣 tailscale.com/util/osdiag/internal/wsc from tailscale.com/util/osdiag
tailscale.com/util/osshare from tailscale.com/ipn/ipnlocal+ tailscale.com/util/osshare from tailscale.com/ipn/ipnlocal+
W tailscale.com/util/pidowner from tailscale.com/ipn/ipnauth W tailscale.com/util/pidowner from tailscale.com/ipn/ipnauth
tailscale.com/util/race from tailscale.com/net/dns/resolver
tailscale.com/util/racebuild from tailscale.com/logpolicy tailscale.com/util/racebuild from tailscale.com/logpolicy
tailscale.com/util/rands from tailscale.com/ipn/ipnlocal+ tailscale.com/util/rands from tailscale.com/ipn/ipnlocal+
tailscale.com/util/ringbuffer from tailscale.com/wgengine/magicsock tailscale.com/util/ringbuffer from tailscale.com/wgengine/magicsock

@ -18,6 +18,7 @@ import (
"sort" "sort"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
dns "golang.org/x/net/dns/dnsmessage" dns "golang.org/x/net/dns/dnsmessage"
@ -35,6 +36,7 @@ import (
"tailscale.com/types/nettype" "tailscale.com/types/nettype"
"tailscale.com/util/cloudenv" "tailscale.com/util/cloudenv"
"tailscale.com/util/dnsname" "tailscale.com/util/dnsname"
"tailscale.com/util/race"
"tailscale.com/version" "tailscale.com/version"
) )
@ -70,6 +72,10 @@ const (
// (e.g. how long to wait to query Google's 8.8.4.4 after 8.8.8.8). // (e.g. how long to wait to query Google's 8.8.4.4 after 8.8.8.8).
wellKnownHostBackupDelay = 200 * time.Millisecond wellKnownHostBackupDelay = 200 * time.Millisecond
// udpRaceTimeout is the timeout after which we will start a DNS query
// over TCP while waiting for the UDP query to complete.
udpRaceTimeout = 2 * time.Second
// tcpQueryTimeout is the timeout for a DNS query performed over TCP. // tcpQueryTimeout is the timeout for a DNS query performed over TCP.
// It matches the default 5sec timeout of the 'dig' utility. // It matches the default 5sec timeout of the 'dig' utility.
tcpQueryTimeout = 5 * time.Second tcpQueryTimeout = 5 * time.Second
@ -488,47 +494,97 @@ func (f *forwarder) send(ctx context.Context, fq *forwardQuery, rr resolverAndDe
return nil, fmt.Errorf("tls:// resolvers not supported yet") return nil, fmt.Errorf("tls:// resolvers not supported yet")
} }
ret, err = f.sendUDP(ctx, fq, rr) ctx, cancel := context.WithCancel(ctx)
if err != nil { defer cancel()
return nil, err
} isUDPQuery := fq.family == "udp"
skipTCP := skipTCPRetry() || (f.controlKnobs != nil && f.controlKnobs.DisableDNSForwarderTCPRetries.Load())
// Print logs about retries if this was because of a truncated response.
var explicitRetry atomic.Bool // true if truncated UDP response retried
defer func() {
if !explicitRetry.Load() {
return
}
if err == nil {
f.logf("forwarder.send(%q): successfully retried via TCP", rr.name.Addr)
} else {
f.logf("forwarder.send(%q): could not retry via TCP: %v", rr.name.Addr, err)
}
}()
firstUDP := func(ctx context.Context) ([]byte, error) {
resp, err := f.sendUDP(ctx, fq, rr)
if err != nil {
return nil, err
}
if !truncatedFlagSet(resp) {
// Successful, non-truncated response; no retry.
return resp, nil
}
if !truncatedFlagSet(ret) {
// Successful, non-truncated response; return it.
return ret, nil
}
if fq.family == "udp" {
// If this is a UDP query, return it regardless of whether the // If this is a UDP query, return it regardless of whether the
// response is truncated or not; the client can retry // response is truncated or not; the client can retry
// communicating with tailscaled over TCP. There's no point // communicating with tailscaled over TCP. There's no point
// falling back to TCP for a truncated query if we can't return // falling back to TCP for a truncated query if we can't return
// the results to the client. // the results to the client.
return ret, nil if isUDPQuery {
return resp, nil
}
if skipTCP {
// Envknob or control knob disabled the TCP retry behaviour;
// just return what we have.
return resp, nil
}
// This is a TCP query from the client, and the UDP response
// from the upstream DNS server is truncated; map this to an
// error to cause our retry helper to immediately kick off the
// TCP retry.
explicitRetry.Store(true)
return nil, truncatedResponseError{resp}
}
thenTCP := func(ctx context.Context) ([]byte, error) {
// If we're skipping the TCP fallback, then wait until the
// context is canceled and return that error (i.e. not
// returning anything).
if skipTCP {
<-ctx.Done()
return nil, ctx.Err()
}
return f.sendTCP(ctx, fq, rr)
} }
if skipTCPRetry() || (f.controlKnobs != nil && f.controlKnobs.DisableDNSForwarderTCPRetries.Load()) {
// Envknob or control knob disabled the TCP retry behaviour; // If the input query is TCP, then don't have a timeout between
// just return what we have. // starting UDP and TCP.
return ret, nil timeout := udpRaceTimeout
if !isUDPQuery {
timeout = 0
} }
// Don't retry if our context is done. // Kick off the race between the UDP and TCP queries.
if err := ctx.Err(); err != nil { rh := race.New[[]byte](timeout, firstUDP, thenTCP)
return nil, err resp, err := rh.Start(ctx)
if err == nil {
return resp, nil
} }
// Retry over TCP, best-effort; return the truncated UDP response if we // If we got a truncated UDP response, return that instead of an error.
// cannot query via TCP. var trErr truncatedResponseError
if ret2, err2 := f.sendTCP(ctx, fq, rr); err2 == nil { if errors.As(err, &trErr) {
if verboseDNSForward() { return trErr.res, nil
f.logf("forwarder.send(%q): successfully retried via TCP", rr.name.Addr)
}
return ret2, nil
} else if verboseDNSForward() {
f.logf("forwarder.send(%q): could not retry via TCP: %v", rr.name.Addr, err2)
} }
return ret, nil return nil, err
} }
type truncatedResponseError struct {
res []byte
}
func (tr truncatedResponseError) Error() string { return "response truncated" }
var errServerFailure = errors.New("response code indicates server issue") var errServerFailure = errors.New("response code indicates server issue")
var errTxIDMismatch = errors.New("txid doesn't match") var errTxIDMismatch = errors.New("txid doesn't match")
@ -875,7 +931,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo
} }
numErr++ numErr++
if numErr == len(resolvers) { if numErr == len(resolvers) {
if firstErr == errServerFailure { if errors.Is(firstErr, errServerFailure) {
res, err := servfailResponse(query) res, err := servfailResponse(query)
if err != nil { if err != nil {
f.logf("building servfail response: %v", err) f.logf("building servfail response: %v", err)

@ -7,6 +7,7 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/binary" "encoding/binary"
"errors"
"flag" "flag"
"fmt" "fmt"
"io" "io"
@ -21,6 +22,7 @@ import (
"time" "time"
dns "golang.org/x/net/dns/dnsmessage" dns "golang.org/x/net/dns/dnsmessage"
"tailscale.com/control/controlknobs"
"tailscale.com/envknob" "tailscale.com/envknob"
"tailscale.com/net/netmon" "tailscale.com/net/netmon"
"tailscale.com/net/tsdial" "tailscale.com/net/tsdial"
@ -253,7 +255,16 @@ func FuzzClampEDNSSize(f *testing.F) {
}) })
} }
func runDNSServer(tb testing.TB, response []byte, onRequest func(bool, []byte)) (port uint16) { type testDNSServerOptions struct {
SkipUDP bool
SkipTCP bool
}
func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, onRequest func(bool, []byte)) (port uint16) {
if opts != nil && opts.SkipUDP && opts.SkipTCP {
tb.Fatal("cannot skip both UDP and TCP servers")
}
tcpResponse := make([]byte, len(response)+2) tcpResponse := make([]byte, len(response)+2)
binary.BigEndian.PutUint16(tcpResponse, uint16(len(response))) binary.BigEndian.PutUint16(tcpResponse, uint16(len(response)))
copy(tcpResponse[2:], response) copy(tcpResponse[2:], response)
@ -327,17 +338,20 @@ func runDNSServer(tb testing.TB, response []byte, onRequest func(bool, []byte))
} }
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1)
go func() { if opts == nil || !opts.SkipTCP {
defer wg.Done() wg.Add(1)
for { go func() {
conn, err := tcpLn.Accept() defer wg.Done()
if err != nil { for {
return conn, err := tcpLn.Accept()
if err != nil {
return
}
go handleConn(conn)
} }
go handleConn(conn) }()
} }
}()
handleUDP := func(addr netip.AddrPort, req []byte) { handleUDP := func(addr netip.AddrPort, req []byte) {
onRequest(false, req) onRequest(false, req)
@ -346,19 +360,21 @@ func runDNSServer(tb testing.TB, response []byte, onRequest func(bool, []byte))
} }
} }
wg.Add(1) if opts == nil || !opts.SkipUDP {
go func() { wg.Add(1)
defer wg.Done() go func() {
for { defer wg.Done()
buf := make([]byte, 65535) for {
n, addr, err := udpLn.ReadFromUDPAddrPort(buf) buf := make([]byte, 65535)
if err != nil { n, addr, err := udpLn.ReadFromUDPAddrPort(buf)
return if err != nil {
return
}
buf = buf[:n]
go handleUDP(addr, buf)
} }
buf = buf[:n] }()
go handleUDP(addr, buf) }
}
}()
tb.Cleanup(func() { tb.Cleanup(func() {
tcpLn.Close() tcpLn.Close()
@ -369,84 +385,72 @@ func runDNSServer(tb testing.TB, response []byte, onRequest func(bool, []byte))
return return
} }
func TestForwarderTCPFallback(t *testing.T) { func enableDebug(tb testing.TB) {
const debugKnob = "TS_DEBUG_DNS_FORWARD_SEND" const debugKnob = "TS_DEBUG_DNS_FORWARD_SEND"
oldVal := os.Getenv(debugKnob) oldVal := os.Getenv(debugKnob)
envknob.Setenv(debugKnob, "true") envknob.Setenv(debugKnob, "true")
t.Cleanup(func() { envknob.Setenv(debugKnob, oldVal) }) tb.Cleanup(func() { envknob.Setenv(debugKnob, oldVal) })
}
const domain = "large-dns-response.tailscale.com."
// Make a response that's very large, containing a bunch of localhost addresses. func makeLargeResponse(tb testing.TB, domain string) (request, response []byte) {
largeResponse := func() []byte { name := dns.MustNewName(domain)
name := dns.MustNewName(domain)
builder := dns.NewBuilder(nil, dns.Header{}) builder := dns.NewBuilder(nil, dns.Header{})
builder.StartQuestions() builder.StartQuestions()
builder.Question(dns.Question{ builder.Question(dns.Question{
Name: name,
Type: dns.TypeA,
Class: dns.ClassINET,
})
builder.StartAnswers()
for i := 0; i < 120; i++ {
builder.AResource(dns.ResourceHeader{
Name: name, Name: name,
Type: dns.TypeA,
Class: dns.ClassINET, Class: dns.ClassINET,
TTL: 300,
}, dns.AResource{
A: [4]byte{127, 0, 0, byte(i)},
}) })
builder.StartAnswers() }
for i := 0; i < 120; i++ {
builder.AResource(dns.ResourceHeader{
Name: name,
Class: dns.ClassINET,
TTL: 300,
}, dns.AResource{
A: [4]byte{127, 0, 0, byte(i)},
})
}
msg, err := builder.Finish() var err error
if err != nil { response, err = builder.Finish()
t.Fatal(err) if err != nil {
} tb.Fatal(err)
return msg }
}() if len(response) <= maxResponseBytes {
if len(largeResponse) <= maxResponseBytes { tb.Fatalf("got len(largeResponse)=%d, want > %d", len(response), maxResponseBytes)
t.Fatalf("got len(largeResponse)=%d, want > %d", len(largeResponse), maxResponseBytes)
} }
// Our request is a single A query for the domain in the answer, above. // Our request is a single A query for the domain in the answer, above.
request := func() []byte { builder = dns.NewBuilder(nil, dns.Header{})
builder := dns.NewBuilder(nil, dns.Header{}) builder.StartQuestions()
builder.StartQuestions() builder.Question(dns.Question{
builder.Question(dns.Question{ Name: dns.MustNewName(domain),
Name: dns.MustNewName(domain), Type: dns.TypeA,
Type: dns.TypeA, Class: dns.ClassINET,
Class: dns.ClassINET,
})
msg, err := builder.Finish()
if err != nil {
t.Fatal(err)
}
return msg
}()
var sawUDPRequest, sawTCPRequest atomic.Bool
port := runDNSServer(t, largeResponse, func(isTCP bool, gotRequest []byte) {
if isTCP {
sawTCPRequest.Store(true)
} else {
sawUDPRequest.Store(true)
}
if !bytes.Equal(request, gotRequest) {
t.Errorf("invalid request\ngot: %+v\nwant: %+v", gotRequest, request)
}
}) })
request, err = builder.Finish()
if err != nil {
tb.Fatal(err)
}
netMon, err := netmon.New(t.Logf) return
}
func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwarder)) ([]byte, error) {
netMon, err := netmon.New(tb.Logf)
if err != nil { if err != nil {
t.Fatal(err) tb.Fatal(err)
} }
var dialer tsdial.Dialer var dialer tsdial.Dialer
dialer.SetNetMon(netMon) dialer.SetNetMon(netMon)
fwd := newForwarder(t.Logf, netMon, nil, &dialer, nil) fwd := newForwarder(tb.Logf, netMon, nil, &dialer, nil)
if modify != nil {
modify(fwd)
}
fq := &forwardQuery{ fq := &forwardQuery{
txid: getTxID(request), txid: getTxID(request),
@ -459,10 +463,41 @@ func TestForwarderTCPFallback(t *testing.T) {
name: &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port)}, name: &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port)},
} }
resp, err := fwd.send(context.Background(), fq, rr) return fwd.send(context.Background(), fq, rr)
}
func mustRunTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwarder)) []byte {
resp, err := runTestQuery(tb, port, request, modify)
if err != nil { if err != nil {
t.Fatalf("error making request: %v", err) tb.Fatalf("error making request: %v", err)
} }
return resp
}
func TestForwarderTCPFallback(t *testing.T) {
enableDebug(t)
const domain = "large-dns-response.tailscale.com."
// Make a response that's very large, containing a bunch of localhost addresses.
request, largeResponse := makeLargeResponse(t, domain)
var sawUDPRequest, sawTCPRequest atomic.Bool
port := runDNSServer(t, nil, largeResponse, func(isTCP bool, gotRequest []byte) {
if isTCP {
t.Logf("saw TCP request")
sawTCPRequest.Store(true)
} else {
t.Logf("saw UDP request")
sawUDPRequest.Store(true)
}
if !bytes.Equal(request, gotRequest) {
t.Errorf("invalid request\ngot: %+v\nwant: %+v", gotRequest, request)
}
})
resp := mustRunTestQuery(t, port, request, nil)
if !bytes.Equal(resp, largeResponse) { if !bytes.Equal(resp, largeResponse) {
t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse) t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse)
} }
@ -473,3 +508,141 @@ func TestForwarderTCPFallback(t *testing.T) {
t.Errorf("DNS server never saw UDP request") t.Errorf("DNS server never saw UDP request")
} }
} }
// Test to ensure that if the UDP listener is unresponsive, we always make a
// TCP request even if we never get a response.
func TestForwarderTCPFallbackTimeout(t *testing.T) {
enableDebug(t)
const domain = "large-dns-response.tailscale.com."
// Make a response that's very large, containing a bunch of localhost addresses.
request, largeResponse := makeLargeResponse(t, domain)
var sawTCPRequest atomic.Bool
opts := &testDNSServerOptions{SkipUDP: true}
port := runDNSServer(t, opts, largeResponse, func(isTCP bool, gotRequest []byte) {
if isTCP {
t.Logf("saw TCP request")
sawTCPRequest.Store(true)
} else {
t.Error("saw unexpected UDP request")
}
if !bytes.Equal(request, gotRequest) {
t.Errorf("invalid request\ngot: %+v\nwant: %+v", gotRequest, request)
}
})
resp := mustRunTestQuery(t, port, request, nil)
if !bytes.Equal(resp, largeResponse) {
t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse)
}
if !sawTCPRequest.Load() {
t.Errorf("DNS server never saw TCP request")
}
}
func TestForwarderTCPFallbackDisabled(t *testing.T) {
enableDebug(t)
const domain = "large-dns-response.tailscale.com."
// Make a response that's very large, containing a bunch of localhost addresses.
request, largeResponse := makeLargeResponse(t, domain)
var sawUDPRequest atomic.Bool
port := runDNSServer(t, nil, largeResponse, func(isTCP bool, gotRequest []byte) {
if isTCP {
t.Error("saw unexpected TCP request")
} else {
t.Logf("saw UDP request")
sawUDPRequest.Store(true)
}
if !bytes.Equal(request, gotRequest) {
t.Errorf("invalid request\ngot: %+v\nwant: %+v", gotRequest, request)
}
})
resp := mustRunTestQuery(t, port, request, func(fwd *forwarder) {
// Disable retries for this test.
fwd.controlKnobs = &controlknobs.Knobs{}
fwd.controlKnobs.DisableDNSForwarderTCPRetries.Store(true)
})
wantResp := append([]byte(nil), largeResponse[:maxResponseBytes]...)
// Set the truncated flag on the expected response, since that's what we expect.
flags := binary.BigEndian.Uint16(wantResp[2:4])
flags |= dnsFlagTruncated
binary.BigEndian.PutUint16(wantResp[2:4], flags)
if !bytes.Equal(resp, wantResp) {
t.Errorf("invalid response\ngot (%d): %+v\nwant (%d): %+v", len(resp), resp, len(wantResp), wantResp)
}
if !sawUDPRequest.Load() {
t.Errorf("DNS server never saw UDP request")
}
}
// Test to ensure that we propagate DNS errors
func TestForwarderTCPFallbackError(t *testing.T) {
enableDebug(t)
const domain = "error-response.tailscale.com."
// Our response is a SERVFAIL
response := func() []byte {
name := dns.MustNewName(domain)
builder := dns.NewBuilder(nil, dns.Header{
RCode: dns.RCodeServerFailure,
})
builder.StartQuestions()
builder.Question(dns.Question{
Name: name,
Type: dns.TypeA,
Class: dns.ClassINET,
})
response, err := builder.Finish()
if err != nil {
t.Fatal(err)
}
return response
}()
// Our request is a single A query for the domain in the answer, above.
request := func() []byte {
builder := dns.NewBuilder(nil, dns.Header{})
builder.StartQuestions()
builder.Question(dns.Question{
Name: dns.MustNewName(domain),
Type: dns.TypeA,
Class: dns.ClassINET,
})
request, err := builder.Finish()
if err != nil {
t.Fatal(err)
}
return request
}()
var sawRequest atomic.Bool
port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) {
sawRequest.Store(true)
if !bytes.Equal(request, gotRequest) {
t.Errorf("invalid request\ngot: %+v\nwant: %+v", gotRequest, request)
}
})
_, err := runTestQuery(t, port, request, nil)
if !sawRequest.Load() {
t.Error("did not see DNS request")
}
if err == nil {
t.Error("wanted error, got nil")
} else if !errors.Is(err, errServerFailure) {
t.Errorf("wanted errServerFailure, got: %v", err)
}
}

@ -1449,7 +1449,7 @@ func TestServfail(t *testing.T) {
r.SetConfig(cfg) r.SetConfig(cfg)
pkt, err := syncRespond(r, dnspacket("test.site.", dns.TypeA, noEdns)) pkt, err := syncRespond(r, dnspacket("test.site.", dns.TypeA, noEdns))
if err != errServerFailure { if !errors.Is(err, errServerFailure) {
t.Errorf("err = %v, want %v", err, errServerFailure) t.Errorf("err = %v, want %v", err, errServerFailure)
} }

@ -0,0 +1,115 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package race contains a helper to "race" two functions, returning the first
// successful result. It also allows explicitly triggering the
// (possibly-waiting) second function when the first function returns an error
// or indicates that it should be retried.
package race
import (
"context"
"errors"
"time"
)
type resultType int
const (
first resultType = iota
second
)
// queryResult is an internal type for storing the result of a function call
type queryResult[T any] struct {
ty resultType
res T
err error
}
// Func is the signature of a function to be called.
type Func[T any] func(context.Context) (T, error)
// Race allows running two functions concurrently and returning the first
// non-error result returned.
type Race[T any] struct {
func1, func2 Func[T]
d time.Duration
results chan queryResult[T]
startFallback chan struct{}
}
// New creates a new Race that, when Start is called, will immediately call
// func1 to obtain a result. After the timeout d or if triggered by an error
// response from func1, func2 will be called.
func New[T any](d time.Duration, func1, func2 Func[T]) *Race[T] {
ret := &Race[T]{
func1: func1,
func2: func2,
d: d,
results: make(chan queryResult[T], 2),
startFallback: make(chan struct{}),
}
return ret
}
// Start will start the "race" process, returning the first non-error result or
// the errors that occurred when calling func1 and/or func2.
func (rh *Race[T]) Start(ctx context.Context) (T, error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
// func1 is started immediately
go func() {
ret, err := rh.func1(ctx)
rh.results <- queryResult[T]{first, ret, err}
}()
// func2 is started after a timeout
go func() {
wait := time.NewTimer(rh.d)
defer wait.Stop()
// Wait for our timeout, trigger, or context to finish.
select {
case <-ctx.Done():
// Nothing to do; we're done
var zero T
rh.results <- queryResult[T]{second, zero, ctx.Err()}
return
case <-rh.startFallback:
case <-wait.C:
}
ret, err := rh.func2(ctx)
rh.results <- queryResult[T]{second, ret, err}
}()
// For each possible result, get it off the channel.
var errs []error
for i := 0; i < 2; i++ {
res := <-rh.results
// If this was an error, store it and hope that the other
// result gives us something.
if res.err != nil {
errs = append(errs, res.err)
// Start the fallback function immediately if this is
// the first function's error, to avoid having
// to wait.
if res.ty == first {
close(rh.startFallback)
}
continue
}
// Got a valid response! Return it.
return res.res, nil
}
// If we get here, both raced functions failed. Return whatever errors
// we have, joined together.
var zero T
return zero, errors.Join(errs...)
}

@ -0,0 +1,89 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package race
import (
"context"
"errors"
"testing"
"time"
)
func TestRaceSuccess1(t *testing.T) {
const want = "success"
rh := New[string](
10*time.Second,
func(context.Context) (string, error) {
return want, nil
}, func(context.Context) (string, error) {
t.Fatal("should not be called")
return "", nil
})
res, err := rh.Start(context.Background())
if err != nil {
t.Fatal(err)
}
if res != want {
t.Errorf("got res=%q, want %q", res, want)
}
}
func TestRaceRetry(t *testing.T) {
const want = "fallback"
rh := New[string](
10*time.Second,
func(context.Context) (string, error) {
return "", errors.New("some error")
}, func(context.Context) (string, error) {
return want, nil
})
res, err := rh.Start(context.Background())
if err != nil {
t.Fatal(err)
}
if res != want {
t.Errorf("got res=%q, want %q", res, want)
}
}
func TestRaceTimeout(t *testing.T) {
const want = "fallback"
rh := New[string](
100*time.Millisecond,
func(ctx context.Context) (string, error) {
// Block forever
<-ctx.Done()
return "", ctx.Err()
}, func(context.Context) (string, error) {
return want, nil
})
res, err := rh.Start(context.Background())
if err != nil {
t.Fatal(err)
}
if res != want {
t.Errorf("got res=%q, want %q", res, want)
}
}
func TestRaceError(t *testing.T) {
err1 := errors.New("error 1")
err2 := errors.New("error 2")
rh := New[string](
100*time.Millisecond,
func(ctx context.Context) (string, error) {
return "", err1
}, func(context.Context) (string, error) {
return "", err2
})
_, err := rh.Start(context.Background())
if !errors.Is(err, err1) {
t.Errorf("wanted err to contain err1; got %v", err)
}
if !errors.Is(err, err2) {
t.Errorf("wanted err to contain err2; got %v", err)
}
}
Loading…
Cancel
Save