diff --git a/ipn/localapi/localapi.go b/ipn/localapi/localapi.go index a25c8e717..3ac544908 100644 --- a/ipn/localapi/localapi.go +++ b/ipn/localapi/localapi.go @@ -408,18 +408,32 @@ func (h *Handler) serveBugReport(w http.ResponseWriter, r *http.Request) { } func (h *Handler) serveWhoIs(w http.ResponseWriter, r *http.Request) { + h.serveWhoIsWithBackend(w, r, h.b) +} + +// localBackendWhoIsMethods is the subset of ipn.LocalBackend as needed +// by the localapi WhoIs method. +type localBackendWhoIsMethods interface { + WhoIs(netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) + PeerCaps(netip.Addr) tailcfg.PeerCapMap +} + +func (h *Handler) serveWhoIsWithBackend(w http.ResponseWriter, r *http.Request, b localBackendWhoIsMethods) { if !h.PermitRead { http.Error(w, "whois access denied", http.StatusForbidden) return } - b := h.b var ipp netip.AddrPort if v := r.FormValue("addr"); v != "" { - var err error - ipp, err = netip.ParseAddrPort(v) - if err != nil { - http.Error(w, "invalid 'addr' parameter", 400) - return + if ip, err := netip.ParseAddr(v); err == nil { + ipp = netip.AddrPortFrom(ip, 0) + } else { + var err error + ipp, err = netip.ParseAddrPort(v) + if err != nil { + http.Error(w, "invalid 'addr' parameter", 400) + return + } } } else { http.Error(w, "missing 'addr' parameter", 400) @@ -433,7 +447,9 @@ func (h *Handler) serveWhoIs(w http.ResponseWriter, r *http.Request) { res := &apitype.WhoIsResponse{ Node: n.AsStruct(), // always non-nil per WhoIsResponse contract UserProfile: &u, // always non-nil per WhoIsResponse contract - CapMap: b.PeerCaps(ipp.Addr()), + } + if n.Addresses().Len() > 0 { + res.CapMap = b.PeerCaps(n.Addresses().At(0).Addr()) } j, err := json.MarshalIndent(res, "", "\t") if err != nil { diff --git a/ipn/localapi/localapi_test.go b/ipn/localapi/localapi_test.go index 057da9039..2741dc0ef 100644 --- a/ipn/localapi/localapi_test.go +++ b/ipn/localapi/localapi_test.go @@ -9,11 +9,15 @@ import ( "io" "net/http" "net/http/httptest" + "net/netip" + "net/url" + "strings" "testing" "tailscale.com/client/tailscale/apitype" "tailscale.com/hostinfo" "tailscale.com/ipn/ipnlocal" + "tailscale.com/tailcfg" "tailscale.com/tstest" ) @@ -77,3 +81,68 @@ func TestSetPushDeviceToken(t *testing.T) { t.Errorf("hostinfo.PushDeviceToken=%q, want %q", got, want) } } + +type whoIsBackend struct { + whoIs func(ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) + peerCaps map[netip.Addr]tailcfg.PeerCapMap +} + +func (b whoIsBackend) WhoIs(ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) { + return b.whoIs(ipp) +} + +func (b whoIsBackend) PeerCaps(ip netip.Addr) tailcfg.PeerCapMap { + return b.peerCaps[ip] +} + +// Tests that the WhoIs handler accepts either IPs or IP:ports. +// +// From https://github.com/tailscale/tailscale/pull/9714 (a PR that is effectively a bug report) +func TestWhoIsJustIP(t *testing.T) { + h := &Handler{ + PermitRead: true, + } + for _, input := range []string{"100.101.102.103", "127.0.0.1:123"} { + rec := httptest.NewRecorder() + t.Run(input, func(t *testing.T) { + b := whoIsBackend{ + whoIs: func(ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) { + if !strings.Contains(input, ":") { + want := netip.MustParseAddrPort("100.101.102.103:0") + if ipp != want { + t.Fatalf("backend called with %v; want %v", ipp, want) + } + } + return (&tailcfg.Node{ + ID: 123, + Addresses: []netip.Prefix{ + netip.MustParsePrefix("100.101.102.103/32"), + }, + }).View(), + tailcfg.UserProfile{ID: 456, DisplayName: "foo"}, + true + }, + peerCaps: map[netip.Addr]tailcfg.PeerCapMap{ + netip.MustParseAddr("100.101.102.103"): map[tailcfg.PeerCapability][]tailcfg.RawMessage{ + "foo": {`"bar"`}, + }, + }, + } + h.serveWhoIsWithBackend(rec, httptest.NewRequest("GET", "/v0/whois?addr="+url.QueryEscape(input), nil), b) + + var res apitype.WhoIsResponse + if err := json.Unmarshal(rec.Body.Bytes(), &res); err != nil { + t.Fatal(err) + } + if got, want := res.Node.ID, tailcfg.NodeID(123); got != want { + t.Errorf("res.Node.ID=%v, want %v", got, want) + } + if got, want := res.UserProfile.DisplayName, "foo"; got != want { + t.Errorf("res.UserProfile.DisplayName=%q, want %q", got, want) + } + if got, want := len(res.CapMap), 1; got != want { + t.Errorf("capmap size=%v, want %v", got, want) + } + }) + } +}