From b382161fe59fc60b8360bad320f2bb3e0b4fff7d Mon Sep 17 00:00:00 2001 From: Smitty Date: Sat, 18 Sep 2021 20:34:33 -0400 Subject: [PATCH] tsdns: don't forward transient DNS errors When a DNS server claims to be unable or unwilling to handle a request, instead of passing that refusal along to the client, just treat it as any other error trying to connect to the DNS server. This prevents DNS requests from failing depending on whether one server responds with a transient error before another server is able to give an actual response. DNS requests that fail only *sometimes* are really hard to diagnose (#1033). Signed-off-by: Smitty --- net/dns/resolver/forwarder.go | 15 +++++++++++ net/dns/resolver/forwarder_test.go | 43 ++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index 4b0be392c..66f7d43cd 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -74,6 +74,15 @@ func getTxID(packet []byte) txid { return txid(dnsid) } +func getRCode(packet []byte) dns.RCode { + if len(packet) < headerBytes { + // treat invalid packets as a refusal + return dns.RCode(5) + } + // get bottom 4 bits of the 4th byte (index 3) + return dns.RCode(packet[3] & 0x0F) +} + // clampEDNSSize attempts to limit the maximum EDNS response size. This is not // an exhaustive solution, instead only easy cases are currently handled in the // interest of speed and reduced complexity.
Only OPT records at the very end of @@ -455,6 +464,12 @@ func (f *forwarder) send(ctx context.Context, fq *forwardQuery, rr resolverAndDe if txid != fq.txid { return nil, errors.New("txid doesn't match") } + rcode := getRCode(out) + // don't forward transient errors back to the client when the server fails + if rcode == dns.RCodeServerFailure { + f.logf("recv: response code indicating server failure: %d", rcode) + return nil, errors.New("response code indicates server issue") + } if truncated { const dnsFlagTruncated = 0x200 diff --git a/net/dns/resolver/forwarder_test.go b/net/dns/resolver/forwarder_test.go index a3b3317b3..8d14e23aa 100644 --- a/net/dns/resolver/forwarder_test.go +++ b/net/dns/resolver/forwarder_test.go @@ -12,6 +12,7 @@ import ( "testing" "time" + dns "golang.org/x/net/dns/dnsmessage" "tailscale.com/types/dnstype" ) @@ -97,3 +98,45 @@ func TestResolversWithDelays(t *testing.T) { } } + +func TestGetRCode(t *testing.T) { + tests := []struct { + name string + packet []byte + want dns.RCode + }{ + { + name: "empty", + packet: []byte{}, + want: dns.RCode(5), + }, + { + name: "too-short", + packet: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + want: dns.RCode(5), + }, + { + name: "noerror", + packet: []byte{0xC4, 0xFE, 0x81, 0xA0, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01}, + want: dns.RCode(0), + }, + { + name: "refused", + packet: []byte{0xee, 0xa1, 0x81, 0x05, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, + want: dns.RCode(5), + }, + { + name: "nxdomain", + packet: []byte{0x34, 0xf4, 0x81, 0x83, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01}, + want: dns.RCode(3), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := getRCode(tt.packet) + if got != tt.want { + t.Errorf("got %d; want %d", got, tt.want) + } + }) + } +}