diff --git a/ipn/ipnlocal/peerapi.go b/ipn/ipnlocal/peerapi.go index 6e45ea7df..176880302 100644 --- a/ipn/ipnlocal/peerapi.go +++ b/ipn/ipnlocal/peerapi.go @@ -1054,6 +1054,9 @@ func writePrettyDNSReply(w io.Writer, res []byte) (err error) { return err } if h.Class != dnsmessage.ClassINET { + if err := p.SkipAnswer(); err != nil { + return err + } continue } switch h.Type { @@ -1075,6 +1078,10 @@ func writePrettyDNSReply(w io.Writer, res []byte) (err error) { return err } gotIPs = append(gotIPs, r.TXT...) + default: + if err := p.SkipAnswer(); err != nil { + return err + } } } j, _ := json.Marshal(gotIPs) diff --git a/ipn/ipnlocal/peerapi_test.go b/ipn/ipnlocal/peerapi_test.go index a5f057bca..2f495b7e7 100644 --- a/ipn/ipnlocal/peerapi_test.go +++ b/ipn/ipnlocal/peerapi_test.go @@ -6,6 +6,7 @@ package ipnlocal import ( "bytes" "context" + "encoding/json" "fmt" "io" "io/fs" @@ -685,6 +686,68 @@ func TestPeerAPIReplyToDNSQueries(t *testing.T) { } } +func TestPeerAPIPrettyReplyCNAME(t *testing.T) { + var h peerAPIHandler + h.remoteAddr = netip.MustParseAddrPort("100.150.151.152:12345") + + eng, _ := wgengine.NewFakeUserspaceEngine(logger.Discard, 0) + pm := must.Get(newProfileManager(new(mem.Store), t.Logf)) + h.ps = &peerAPIServer{ + b: &LocalBackend{ + e: eng, + pm: pm, + store: pm.Store(), + // configure as an app connector just to enable the API. + appConnector: appc.NewAppConnector(t.Logf, &appctest.RouteCollector{}), + }, + } + + h.ps.resolver = &fakeResolver{build: func(b *dnsmessage.Builder) { + b.CNAMEResource( + dnsmessage.ResourceHeader{ + Name: dnsmessage.MustNewName("www.example.com."), + Type: dnsmessage.TypeCNAME, + Class: dnsmessage.ClassINET, + TTL: 0, + }, + dnsmessage.CNAMEResource{ + CNAME: dnsmessage.MustNewName("example.com."), + }, + ) + b.AResource( + dnsmessage.ResourceHeader{ + Name: dnsmessage.MustNewName("example.com."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + TTL: 0, + }, + dnsmessage.AResource{ + A: [4]byte{192, 0, 0, 8}, + }, + ) + }} + f := filter.NewAllowAllForTest(logger.Discard) + h.ps.b.setFilter(f) + + if !h.replyToDNSQueries() { + t.Errorf("unexpectedly deny; wanted to be a DNS server") + } + + w := httptest.NewRecorder() + h.handleDNSQuery(w, httptest.NewRequest("GET", "/dns-query?q=www.example.com.", nil)) + if w.Code != http.StatusOK { + t.Errorf("unexpected status code: %v", w.Code) + } + var addrs []string + json.NewDecoder(w.Body).Decode(&addrs) + if len(addrs) == 0 { + t.Fatalf("no addresses returned") + } + for _, addr := range addrs { + netip.MustParseAddr(addr) + } +} + func TestPeerAPIReplyToDNSQueriesAreObserved(t *testing.T) { ctx := context.Background() var h peerAPIHandler @@ -704,7 +767,19 @@ func TestPeerAPIReplyToDNSQueriesAreObserved(t *testing.T) { h.ps.b.appConnector.UpdateDomains([]string{"example.com"}) h.ps.b.appConnector.Wait(ctx) - h.ps.resolver = &fakeResolver{} + h.ps.resolver = &fakeResolver{build: func(b *dnsmessage.Builder) { + b.AResource( + dnsmessage.ResourceHeader{ + Name: dnsmessage.MustNewName("example.com."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + TTL: 0, + }, + dnsmessage.AResource{ + A: [4]byte{192, 0, 0, 8}, + }, + ) + }} f := filter.NewAllowAllForTest(logger.Discard) h.ps.b.setFilter(f) @@ -716,7 +791,7 @@ func TestPeerAPIReplyToDNSQueriesAreObserved(t *testing.T) { } w := httptest.NewRecorder() - h.handleDNSQuery(w, httptest.NewRequest("GET", "/dns-query?q=true&t=example.com.", nil)) + h.handleDNSQuery(w, httptest.NewRequest("GET", "/dns-query?q=example.com.", nil)) if w.Code != http.StatusOK { t.Errorf("unexpected status code: %v", w.Code) } @@ -728,22 +803,14 @@ func TestPeerAPIReplyToDNSQueriesAreObserved(t *testing.T) { } } -type fakeResolver struct{} +type fakeResolver struct { + build func(*dnsmessage.Builder) +} func (f *fakeResolver) HandlePeerDNSQuery(ctx context.Context, q []byte, from netip.AddrPort, allowName func(name string) bool) (res []byte, err error) { b := dnsmessage.NewBuilder(nil, dnsmessage.Header{}) b.EnableCompression() b.StartAnswers() - b.AResource( - dnsmessage.ResourceHeader{ - Name: dnsmessage.MustNewName("example.com."), - Type: dnsmessage.TypeA, - Class: dnsmessage.ClassINET, - TTL: 0, - }, - dnsmessage.AResource{ - A: [4]byte{192, 0, 0, 8}, - }, - ) + f.build(&b) return b.Finish() }