mirror of https://github.com/tailscale/tailscale/
net/dns/resolver: race UDP and TCP queries (#9544)
Instead of just falling back to making a TCP query to an upstream DNS server when the UDP query returns a truncated query, also start a TCP query in parallel with the UDP query after a given race timeout. This ensures that if the upstream DNS server does not reply over UDP (or if the response packet is blocked, or there's an error), we can still make queries if the server replies to TCP queries. This also adds a new package, util/race, to contain the logic required for racing two different functions and returning the first non-error answer. Updates tailscale/corp#14809 Signed-off-by: Andrew Dunham <andrew@du.nham.ca> Change-Id: I4311702016c1093b1beaa31b135da1def6d86316pull/9641/head
parent
eb22c0dfc7
commit
286c6ce27c
@ -0,0 +1,115 @@
|
|||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
// Package race contains a helper to "race" two functions, returning the first
|
||||||
|
// successful result. It also allows explicitly triggering the
|
||||||
|
// (possibly-waiting) second function when the first function returns an error
|
||||||
|
// or indicates that it should be retried.
|
||||||
|
package race
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type resultType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
first resultType = iota
|
||||||
|
second
|
||||||
|
)
|
||||||
|
|
||||||
|
// queryResult is an internal type for storing the result of a function call
|
||||||
|
type queryResult[T any] struct {
|
||||||
|
ty resultType
|
||||||
|
res T
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Func is the signature of a function to be called.
|
||||||
|
type Func[T any] func(context.Context) (T, error)
|
||||||
|
|
||||||
|
// Race allows running two functions concurrently and returning the first
|
||||||
|
// non-error result returned.
|
||||||
|
type Race[T any] struct {
|
||||||
|
func1, func2 Func[T]
|
||||||
|
d time.Duration
|
||||||
|
results chan queryResult[T]
|
||||||
|
startFallback chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new Race that, when Start is called, will immediately call
|
||||||
|
// func1 to obtain a result. After the timeout d or if triggered by an error
|
||||||
|
// response from func1, func2 will be called.
|
||||||
|
func New[T any](d time.Duration, func1, func2 Func[T]) *Race[T] {
|
||||||
|
ret := &Race[T]{
|
||||||
|
func1: func1,
|
||||||
|
func2: func2,
|
||||||
|
d: d,
|
||||||
|
results: make(chan queryResult[T], 2),
|
||||||
|
startFallback: make(chan struct{}),
|
||||||
|
}
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start will start the "race" process, returning the first non-error result or
|
||||||
|
// the errors that occurred when calling func1 and/or func2.
|
||||||
|
func (rh *Race[T]) Start(ctx context.Context) (T, error) {
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// func1 is started immediately
|
||||||
|
go func() {
|
||||||
|
ret, err := rh.func1(ctx)
|
||||||
|
rh.results <- queryResult[T]{first, ret, err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// func2 is started after a timeout
|
||||||
|
go func() {
|
||||||
|
wait := time.NewTimer(rh.d)
|
||||||
|
defer wait.Stop()
|
||||||
|
|
||||||
|
// Wait for our timeout, trigger, or context to finish.
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
// Nothing to do; we're done
|
||||||
|
var zero T
|
||||||
|
rh.results <- queryResult[T]{second, zero, ctx.Err()}
|
||||||
|
return
|
||||||
|
case <-rh.startFallback:
|
||||||
|
case <-wait.C:
|
||||||
|
}
|
||||||
|
|
||||||
|
ret, err := rh.func2(ctx)
|
||||||
|
rh.results <- queryResult[T]{second, ret, err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// For each possible result, get it off the channel.
|
||||||
|
var errs []error
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
res := <-rh.results
|
||||||
|
|
||||||
|
// If this was an error, store it and hope that the other
|
||||||
|
// result gives us something.
|
||||||
|
if res.err != nil {
|
||||||
|
errs = append(errs, res.err)
|
||||||
|
|
||||||
|
// Start the fallback function immediately if this is
|
||||||
|
// the first function's error, to avoid having
|
||||||
|
// to wait.
|
||||||
|
if res.ty == first {
|
||||||
|
close(rh.startFallback)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Got a valid response! Return it.
|
||||||
|
return res.res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we get here, both raced functions failed. Return whatever errors
|
||||||
|
// we have, joined together.
|
||||||
|
var zero T
|
||||||
|
return zero, errors.Join(errs...)
|
||||||
|
}
|
@ -0,0 +1,89 @@
|
|||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
package race
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRaceSuccess1(t *testing.T) {
|
||||||
|
const want = "success"
|
||||||
|
rh := New[string](
|
||||||
|
10*time.Second,
|
||||||
|
func(context.Context) (string, error) {
|
||||||
|
return want, nil
|
||||||
|
}, func(context.Context) (string, error) {
|
||||||
|
t.Fatal("should not be called")
|
||||||
|
return "", nil
|
||||||
|
})
|
||||||
|
res, err := rh.Start(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if res != want {
|
||||||
|
t.Errorf("got res=%q, want %q", res, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRaceRetry(t *testing.T) {
|
||||||
|
const want = "fallback"
|
||||||
|
rh := New[string](
|
||||||
|
10*time.Second,
|
||||||
|
func(context.Context) (string, error) {
|
||||||
|
return "", errors.New("some error")
|
||||||
|
}, func(context.Context) (string, error) {
|
||||||
|
return want, nil
|
||||||
|
})
|
||||||
|
res, err := rh.Start(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if res != want {
|
||||||
|
t.Errorf("got res=%q, want %q", res, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRaceTimeout(t *testing.T) {
|
||||||
|
const want = "fallback"
|
||||||
|
rh := New[string](
|
||||||
|
100*time.Millisecond,
|
||||||
|
func(ctx context.Context) (string, error) {
|
||||||
|
// Block forever
|
||||||
|
<-ctx.Done()
|
||||||
|
return "", ctx.Err()
|
||||||
|
}, func(context.Context) (string, error) {
|
||||||
|
return want, nil
|
||||||
|
})
|
||||||
|
res, err := rh.Start(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if res != want {
|
||||||
|
t.Errorf("got res=%q, want %q", res, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRaceError(t *testing.T) {
|
||||||
|
err1 := errors.New("error 1")
|
||||||
|
err2 := errors.New("error 2")
|
||||||
|
|
||||||
|
rh := New[string](
|
||||||
|
100*time.Millisecond,
|
||||||
|
func(ctx context.Context) (string, error) {
|
||||||
|
return "", err1
|
||||||
|
}, func(context.Context) (string, error) {
|
||||||
|
return "", err2
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := rh.Start(context.Background())
|
||||||
|
if !errors.Is(err, err1) {
|
||||||
|
t.Errorf("wanted err to contain err1; got %v", err)
|
||||||
|
}
|
||||||
|
if !errors.Is(err, err2) {
|
||||||
|
t.Errorf("wanted err to contain err2; got %v", err)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue