diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index 4b0be392c..66f7d43cd 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -74,6 +74,15 @@ func getTxID(packet []byte) txid { return txid(dnsid) } +func getRCode(packet []byte) dns.RCode { + if len(packet) < headerBytes { + // treat invalid packets as a refusal + return dns.RCode(5) + } + // get bottom 4 bits of 3rd byte + return dns.RCode(packet[3] & 0x0F) +} + // clampEDNSSize attempts to limit the maximum EDNS response size. This is not // an exhaustive solution, instead only easy cases are currently handled in the // interest of speed and reduced complexity. Only OPT records at the very end of @@ -455,6 +464,12 @@ func (f *forwarder) send(ctx context.Context, fq *forwardQuery, rr resolverAndDe if txid != fq.txid { return nil, errors.New("txid doesn't match") } + rcode := getRCode(out) + // don't forward transient errors back to the client when the server fails + if rcode == dns.RCodeServerFailure { + f.logf("recv: response code indicating server failure: %d", rcode) + return nil, errors.New("response code indicates server issue") + } if truncated { const dnsFlagTruncated = 0x200 diff --git a/net/dns/resolver/forwarder_test.go b/net/dns/resolver/forwarder_test.go index a3b3317b3..8d14e23aa 100644 --- a/net/dns/resolver/forwarder_test.go +++ b/net/dns/resolver/forwarder_test.go @@ -12,6 +12,7 @@ import ( "testing" "time" + dns "golang.org/x/net/dns/dnsmessage" "tailscale.com/types/dnstype" ) @@ -97,3 +98,45 @@ func TestResolversWithDelays(t *testing.T) { } } + +func TestGetRCode(t *testing.T) { + tests := []struct { + name string + packet []byte + want dns.RCode + }{ + { + name: "empty", + packet: []byte{}, + want: dns.RCode(5), + }, + { + name: "too-short", + packet: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + want: dns.RCode(5), + }, + { + name: "noerror", + packet: []byte{0xC4, 0xFE, 0x81, 0xA0, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01}, + want: dns.RCode(0), + }, + { + name: "refused", + packet: []byte{0xee, 0xa1, 0x81, 0x05, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}, + want: dns.RCode(5), + }, + { + name: "nxdomain", + packet: []byte{0x34, 0xf4, 0x81, 0x83, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01}, + want: dns.RCode(3), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := getRCode(tt.packet) + if got != tt.want { + t.Errorf("got %d; want %d", got, tt.want) + } + }) + } +}