diff --git a/net/portmapper/portmapper.go b/net/portmapper/portmapper.go index 3bee9e5f9..1105c947d 100644 --- a/net/portmapper/portmapper.go +++ b/net/portmapper/portmapper.go @@ -1015,6 +1015,30 @@ var ( // received a UPnP response from a port other than the UPnP port. metricUPnPResponseAlternatePort = clientmetric.NewCounter("portmap_upnp_response_alternate_port") + // metricUPnPSelectSingle counts the number of times that only a single + // UPnP device was available in selectBestService. + metricUPnPSelectSingle = clientmetric.NewCounter("portmap_upnp_select_single") + + // metricUPnPSelectMultiple counts the number of times that we need to + // select from among multiple UPnP devices in selectBestService. + metricUPnPSelectMultiple = clientmetric.NewCounter("portmap_upnp_select_multiple") + + // metricUPnPSelectExternalPublic counts the number of times that + // selectBestService picked a UPnP device with an external public IP. + metricUPnPSelectExternalPublic = clientmetric.NewCounter("portmap_upnp_select_external_public") + + // metricUPnPSelectExternalPrivate counts the number of times that + // selectBestService picked a UPnP device with an external private IP. + metricUPnPSelectExternalPrivate = clientmetric.NewCounter("portmap_upnp_select_external_private") + + // metricUPnPSelectUp counts the number of times that selectBestService + // picked a UPnP device that was up but with no external IP. + metricUPnPSelectUp = clientmetric.NewCounter("portmap_upnp_select_up") + + // metricUPnPSelectNone counts the number of times that selectBestService + // picked a UPnP device that is not up. + metricUPnPSelectNone = clientmetric.NewCounter("portmap_upnp_select_none") + // metricUPnPParseErr counts the number of times we failed to parse a UPnP response. metricUPnPParseErr = clientmetric.NewCounter("portmap_upnp_parse_err") diff --git a/net/portmapper/select_test.go b/net/portmapper/select_test.go new file mode 100644 index 000000000..9e99c9a9d --- /dev/null +++ b/net/portmapper/select_test.go @@ -0,0 +1,360 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package portmapper + +import ( + "context" + "encoding/xml" + "net/http" + "net/url" + "strings" + "testing" + + "github.com/tailscale/goupnp" + "github.com/tailscale/goupnp/dcps/internetgateway2" +) + +// NOTE: this is in a distinct file because the various string constants are +// pretty verbose. + +func TestSelectBestService(t *testing.T) { + mustParseURL := func(ss string) *url.URL { + u, err := url.Parse(ss) + if err != nil { + t.Fatalf("error parsing URL %q: %v", ss, err) + } + return u + } + + // Run a fake IGD server to respond to UPnP requests. + igd, err := NewTestIGD(t.Logf, TestIGDOptions{UPnP: true}) + if err != nil { + t.Fatal(err) + } + defer igd.Close() + + testCases := []struct { + name string + rootDesc string + control map[string]map[string]any + want string // controlURL field + }{ + { + name: "single_device", + rootDesc: testRootDesc, + control: map[string]map[string]any{ + // Service that's up and should be selected. + "/ctl/IPConn": { + "GetExternalIPAddress": testGetExternalIPAddressResponse, + "GetStatusInfo": testGetStatusInfoResponse, + }, + }, + want: "/ctl/IPConn", + }, + { + name: "first_device_disconnected", + rootDesc: testSelectRootDesc, + control: map[string]map[string]any{ + // Service that's down; it's important that this is the + // one that's down since it's ordered first in the XML + // and we want to verify that our code properly queries + // and then skips it. + "/upnp/control/yomkmsnooi/wanipconn-1": { + "GetStatusInfo": testGetStatusInfoResponseDisconnected, + // NOTE: nothing else should be called + // if GetStatusInfo returns a + // disconnected result + }, + // Service that's up and should be selected. + "/upnp/control/xstnsgeuyh/wanipconn-7": { + "GetExternalIPAddress": testGetExternalIPAddressResponse, + "GetStatusInfo": testGetStatusInfoResponse, + }, + }, + want: "/upnp/control/xstnsgeuyh/wanipconn-7", + }, + { + name: "prefer_public_external_IP", + rootDesc: testSelectRootDesc, + control: map[string]map[string]any{ + // Service with a private external IP; order matters as above. + "/upnp/control/yomkmsnooi/wanipconn-1": { + "GetStatusInfo": testGetStatusInfoResponse, + "GetExternalIPAddress": testGetExternalIPAddressResponsePrivate, + }, + // Service that's up and should be selected. + "/upnp/control/xstnsgeuyh/wanipconn-7": { + "GetExternalIPAddress": testGetExternalIPAddressResponse, + "GetStatusInfo": testGetStatusInfoResponse, + }, + }, + want: "/upnp/control/xstnsgeuyh/wanipconn-7", + }, + { + name: "all_private_external_IPs", + rootDesc: testSelectRootDesc, + control: map[string]map[string]any{ + "/upnp/control/yomkmsnooi/wanipconn-1": { + "GetStatusInfo": testGetStatusInfoResponse, + "GetExternalIPAddress": testGetExternalIPAddressResponsePrivate, + }, + "/upnp/control/xstnsgeuyh/wanipconn-7": { + "GetStatusInfo": testGetStatusInfoResponse, + "GetExternalIPAddress": testGetExternalIPAddressResponsePrivate, + }, + }, + want: "/upnp/control/yomkmsnooi/wanipconn-1", // since this is first in the XML + }, + { + name: "nothing_connected", + rootDesc: testSelectRootDesc, + control: map[string]map[string]any{ + "/upnp/control/yomkmsnooi/wanipconn-1": { + "GetStatusInfo": testGetStatusInfoResponseDisconnected, + }, + "/upnp/control/xstnsgeuyh/wanipconn-7": { + "GetStatusInfo": testGetStatusInfoResponseDisconnected, + }, + }, + want: "/upnp/control/yomkmsnooi/wanipconn-1", // since this is first in the XML + }, + { + name: "GetStatusInfo_errors", + rootDesc: testSelectRootDesc, + control: map[string]map[string]any{ + "/upnp/control/yomkmsnooi/wanipconn-1": { + "GetStatusInfo": func(_ string) (int, string) { + return http.StatusInternalServerError, "internal error" + }, + }, + "/upnp/control/xstnsgeuyh/wanipconn-7": { + "GetStatusInfo": func(_ string) (int, string) { + return http.StatusNotFound, "not found" + }, + }, + }, + want: "/upnp/control/yomkmsnooi/wanipconn-1", // since this is first in the XML + }, + { + name: "GetExternalIPAddress_bad_ip", + rootDesc: testSelectRootDesc, + control: map[string]map[string]any{ + "/upnp/control/yomkmsnooi/wanipconn-1": { + "GetStatusInfo": testGetStatusInfoResponse, + "GetExternalIPAddress": testGetExternalIPAddressResponseInvalid, + }, + "/upnp/control/xstnsgeuyh/wanipconn-7": { + "GetStatusInfo": testGetStatusInfoResponse, + "GetExternalIPAddress": testGetExternalIPAddressResponse, + }, + }, + want: "/upnp/control/xstnsgeuyh/wanipconn-7", + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + // Ensure that we're using our test IGD server for all requests. + rootDesc := strings.ReplaceAll(tt.rootDesc, "@SERVERURL@", igd.ts.URL) + + igd.SetUPnPHandler(&upnpServer{ + t: t, + Desc: rootDesc, + Control: tt.control, + }) + c := newTestClient(t, igd) + t.Logf("Listening on upnp=%v", c.testUPnPPort) + defer c.Close() + + // Ensure that we're using the HTTP client that talks to our test IGD server + ctx := context.Background() + ctx = goupnp.WithHTTPClient(ctx, c.upnpHTTPClientLocked()) + + loc := mustParseURL(igd.ts.URL) + rootDev := mustParseRootDev(t, rootDesc, loc) + + svc, err := selectBestService(ctx, t.Logf, rootDev, loc) + if err != nil { + t.Fatal(err) + } + + var controlURL string + switch v := svc.(type) { + case *internetgateway2.WANIPConnection2: + controlURL = v.ServiceClient.Service.ControlURL.Str + case *internetgateway2.WANIPConnection1: + controlURL = v.ServiceClient.Service.ControlURL.Str + case *internetgateway2.WANPPPConnection1: + controlURL = v.ServiceClient.Service.ControlURL.Str + default: + t.Fatalf("unknown client type: %T", v) + } + + if controlURL != tt.want { + t.Errorf("mismatched controlURL: got=%q want=%q", controlURL, tt.want) + } + }) + } +} + +func mustParseRootDev(t *testing.T, devXML string, loc *url.URL) *goupnp.RootDevice { + decoder := xml.NewDecoder(strings.NewReader(devXML)) + decoder.DefaultSpace = goupnp.DeviceXMLNamespace + decoder.CharsetReader = goupnp.CharsetReaderDefault + + root := new(goupnp.RootDevice) + if err := decoder.Decode(root); err != nil { + t.Fatalf("error decoding device XML: %v", err) + } + + // Ensure the URLBase is set properly; this is how DeviceByURL does it. + var urlBaseStr string + if root.URLBaseStr != "" { + urlBaseStr = root.URLBaseStr + } else { + urlBaseStr = loc.String() + } + urlBase, err := url.Parse(urlBaseStr) + if err != nil { + t.Fatalf("error parsing URL %q: %v", urlBaseStr, err) + } + root.SetURLBase(urlBase) + + return root +} + +// Note: adapted from mikrotikRootDescXML with addresses replaced with +// localhost, and unnecessary fields removed. +const testSelectRootDesc = ` + + + 1 + 0 + + + urn:schemas-upnp-org:device:InternetGatewayDevice:1 + MikroTik Router + MikroTik + https://www.mikrotik.com/ + Router OS + uuid:UUID-MIKROTIK-INTERNET-GATEWAY-DEVICE- + + + urn:schemas-microsoft-com:service:OSInfo:1 + urn:microsoft-com:serviceId:OSInfo1 + /osinfo.xml + /upnp/control/oqjsxqshhz/osinfo + /upnp/event/cwzcyndrjf/osinfo + + + + + urn:schemas-upnp-org:device:WANDevice:1 + WAN Device + MikroTik + https://www.mikrotik.com/ + Router OS + uuid:UUID-MIKROTIK-WAN-DEVICE--1 + + + urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1 + urn:upnp-org:serviceId:WANCommonIFC1 + /wancommonifc-1.xml + /upnp/control/ivvmxhunyq/wancommonifc-1 + /upnp/event/mkjzdqvryf/wancommonifc-1 + + + + + urn:schemas-upnp-org:device:WANConnectionDevice:1 + WAN Connection Device + MikroTik + https://www.mikrotik.com/ + Router OS + uuid:UUID-MIKROTIK-WAN-CONNECTION-DEVICE--1 + + + urn:schemas-upnp-org:service:WANIPConnection:1 + urn:upnp-org:serviceId:WANIPConn1 + /wanipconn-1.xml + /upnp/control/yomkmsnooi/wanipconn-1 + /upnp/event/veeabhzzva/wanipconn-1 + + + + + + + urn:schemas-upnp-org:device:WANDevice:1 + WAN Device + MikroTik + https://www.mikrotik.com/ + Router OS + uuid:UUID-MIKROTIK-WAN-DEVICE--7 + + + urn:schemas-upnp-org:service:WANCommonInterfaceConfig:1 + urn:upnp-org:serviceId:WANCommonIFC1 + /wancommonifc-7.xml + /upnp/control/vzcyyzzttz/wancommonifc-7 + /upnp/event/womwbqtbkq/wancommonifc-7 + + + + + urn:schemas-upnp-org:device:WANConnectionDevice:1 + WAN Connection Device + MikroTik + https://www.mikrotik.com/ + Router OS + uuid:UUID-MIKROTIK-WAN-CONNECTION-DEVICE--7 + + + urn:schemas-upnp-org:service:WANIPConnection:1 + urn:upnp-org:serviceId:WANIPConn1 + /wanipconn-7.xml + /upnp/control/xstnsgeuyh/wanipconn-7 + /upnp/event/rscixkusbs/wanipconn-7 + + + + + + + @SERVERURL@ + + @SERVERURL@ +` + +const testGetStatusInfoResponseDisconnected = ` + + + + Disconnected + ERROR_NONE + 0 + + + +` + +const testGetExternalIPAddressResponsePrivate = ` + + + + 10.9.8.7 + + + +` + +const testGetExternalIPAddressResponseInvalid = ` + + + + not-an-ip-addr + + + +` diff --git a/net/portmapper/upnp.go b/net/portmapper/upnp.go index 31650a516..bc705dc20 100644 --- a/net/portmapper/upnp.go +++ b/net/portmapper/upnp.go @@ -29,6 +29,7 @@ import ( "tailscale.com/envknob" "tailscale.com/net/netns" "tailscale.com/types/logger" + "tailscale.com/util/mak" ) // References: @@ -44,7 +45,14 @@ type upnpMapping struct { goodUntil time.Time renewAfter time.Time - // client is a connection to a upnp device, and may be reused across different UPnP mappings. + // rootDev is the UPnP root device, and may be reused across different + // UPnP mappings. + rootDev *goupnp.RootDevice + // loc is the location used to fetch the rootDev + loc *url.URL + // client is the most recent UPnP client used, and should only be used + // to release an existing mapping; new mappings should be selected from + // the rootDev on each attempt. client upnpClient } @@ -104,6 +112,7 @@ type upnpClient interface { DeletePortMapping(ctx context.Context, remoteHost string, externalPort uint16, protocol string) error GetExternalIPAddress(ctx context.Context) (externalIPAddress string, err error) + GetStatusInfo(ctx context.Context) (status string, lastConnError string, uptime uint32, err error) } // tsPortMappingDesc gets sent to UPnP clients as a human-readable label for the portmapping. @@ -182,24 +191,21 @@ func addAnyPortMapping( return externalPort, err } -// getUPnPClient gets a client for interfacing with UPnP, ignoring the underlying protocol for -// now. +// getUPnPRootDevice fetches the UPnP root device given the discovery response, +// ignoring the underlying protocol for now. // Adapted from https://github.com/huin/goupnp/blob/master/GUIDE.md. // // The gw is the detected gateway. // // The meta is the most recently parsed UDP discovery packet response // from the Internet Gateway Device. -// -// The provided ctx is not retained in the returned upnpClient, but -// its associated HTTP client is (if set via goupnp.WithHTTPClient). -func getUPnPClient(ctx context.Context, logf logger.Logf, debug DebugKnobs, gw netip.Addr, meta uPnPDiscoResponse) (client upnpClient, err error) { +func getUPnPRootDevice(ctx context.Context, logf logger.Logf, debug DebugKnobs, gw netip.Addr, meta uPnPDiscoResponse) (rootDev *goupnp.RootDevice, loc *url.URL, err error) { if debug.DisableUPnP { - return nil, nil + return nil, nil, nil } if meta.Location == "" { - return nil, nil + return nil, nil, nil } if debug.VerboseLogs { @@ -207,12 +213,12 @@ func getUPnPClient(ctx context.Context, logf logger.Logf, debug DebugKnobs, gw n } u, err := url.Parse(meta.Location) if err != nil { - return nil, err + return nil, nil, err } ipp, err := netip.ParseAddrPort(u.Host) if err != nil { - return nil, fmt.Errorf("unexpected host %q in %q", u.Host, meta.Location) + return nil, nil, fmt.Errorf("unexpected host %q in %q", u.Host, meta.Location) } if ipp.Addr() != gw { // https://github.com/tailscale/tailscale/issues/5502 @@ -231,30 +237,150 @@ func getUPnPClient(ctx context.Context, logf logger.Logf, debug DebugKnobs, gw n // This part does a network fetch. root, err := goupnp.DeviceByURL(ctx, u) if err != nil { - return nil, err + return nil, nil, err } + return root, u, nil +} +// selectBestService picks the "best" service from the given UPnP root device +// to use to create a port mapping. +// +// loc is the parsed location that was used to fetch the given RootDevice. +// +// The provided ctx is not retained in the returned upnpClient, but +// its associated HTTP client is (if set via goupnp.WithHTTPClient). +func selectBestService(ctx context.Context, logf logger.Logf, root *goupnp.RootDevice, loc *url.URL) (client upnpClient, err error) { + method := "none" defer func() { if client == nil { return } - logf("saw UPnP type %v at %v; %v (%v)", + logf("saw UPnP type %v at %v; %v (%v), method=%s", strings.TrimPrefix(fmt.Sprintf("%T", client), "*internetgateway2."), - meta.Location, root.Device.FriendlyName, root.Device.Manufacturer) + loc, root.Device.FriendlyName, root.Device.Manufacturer, + method) }() - // These parts don't do a network fetch. - // Pick the best service type available. - if cc, _ := internetgateway2.NewWANIPConnection2ClientsFromRootDevice(ctx, root, u); len(cc) > 0 { - return cc[0], nil + // First, get all available clients from the device, and append to our + // list of possible clients. Order matters here; we want to prefer + // WANIPConnection2 over WANIPConnection1 or WANPPPConnection. + wanIP2, _ := internetgateway2.NewWANIPConnection2ClientsFromRootDevice(ctx, root, loc) + wanIP1, _ := internetgateway2.NewWANIPConnection1ClientsFromRootDevice(ctx, root, loc) + wanPPP, _ := internetgateway2.NewWANPPPConnection1ClientsFromRootDevice(ctx, root, loc) + + var clients []upnpClient + for _, v := range wanIP2 { + clients = append(clients, v) + } + for _, v := range wanIP1 { + clients = append(clients, v) + } + for _, v := range wanPPP { + clients = append(clients, v) + } + + // If we have no clients, then return right now; if we only have one, + // just select and return it. + if len(clients) == 0 { + return nil, nil + } + if len(clients) == 1 { + method = "single" + metricUPnPSelectSingle.Add(1) + return clients[0], nil + } + + metricUPnPSelectMultiple.Add(1) + + // In order to maximize the chances that we find a valid UPnP device + // that can give us a port mapping, we check a few properties: + // 1. Whether the device is "online", as defined by GetStatusInfo + // 2. Whether the device has an external IP address, as defined by + // GetExternalIPAddress + // 3. Whether the device's external IP address is a public address + // or a private one. + // + // We prefer a device where all of the above is true, and fall back if + // none are found. + // + // In order to save on network requests, iterate through all devices + // and determine how many "points" they have based on the above + // criteria, but return immediately if we find one that meets all + // three. + var ( + connected = make(map[upnpClient]bool) + externalIPs map[upnpClient]netip.Addr + ) + for _, svc := range clients { + isConnected := serviceIsConnected(ctx, logf, svc) + connected[svc] = isConnected + + // Don't bother checking for an external IP if the device isn't + // connected; technically this could happen with a misbehaving + // device, but that seems unlikely. + if !isConnected { + continue + } + + // Check if the device has an external IP address. + extIP, err := svc.GetExternalIPAddress(ctx) + if err != nil { + continue + } + externalIP, err := netip.ParseAddr(extIP) + if err != nil { + continue + } + mak.Set(&externalIPs, svc, externalIP) + + // If we get here, this device has a non-private external IP + // and is up, so we can just return it. + if !externalIP.IsPrivate() { + method = "ext-public" + metricUPnPSelectExternalPublic.Add(1) + return svc, nil + } } - if cc, _ := internetgateway2.NewWANIPConnection1ClientsFromRootDevice(ctx, root, u); len(cc) > 0 { - return cc[0], nil + + // Okay, we have no devices that meet all the available options. Fall + // back to first checking for devices that are up and have a private + // external IP (order matters), and then devices that are up, and then + // just anything at all. + // + // try=0 Up + private external IP + // try=1 Up + for try := 0; try <= 1; try++ { + for _, svc := range clients { + if !connected[svc] { + continue + } + _, hasExtIP := externalIPs[svc] + if hasExtIP { + method = "ext-private" + metricUPnPSelectExternalPrivate.Add(1) + return svc, nil + } else if try == 1 { + method = "up" + metricUPnPSelectUp.Add(1) + return svc, nil + } + } } - if cc, _ := internetgateway2.NewWANPPPConnection1ClientsFromRootDevice(ctx, root, u); len(cc) > 0 { - return cc[0], nil + + // Nothing is up, but we have something (length of clients checked + // above); just return the first one. + metricUPnPSelectNone.Add(1) + return clients[0], nil +} + +// serviceIsConnected returns whether a given UPnP service is connected, based +// on the NewConnectionStatus field returned from GetStatusInfo. +func serviceIsConnected(ctx context.Context, logf logger.Logf, svc upnpClient) bool { + status, _ /* NewLastConnectionError */, _ /* NewUptime */, err := svc.GetStatusInfo(ctx) + if err != nil { + return false } - return nil, nil + return status == "Connected" || status == "Up" } func (c *Client) upnpHTTPClientLocked() *http.Client { @@ -295,26 +421,37 @@ func (c *Client) getUPnPPortMapping( internal: internal, } - var client upnpClient - var err error + var ( + rootDev *goupnp.RootDevice + loc *url.URL + err error + ) c.mu.Lock() oldMapping, ok := c.mapping.(*upnpMapping) meta := c.uPnPMeta - httpClient := c.upnpHTTPClientLocked() + ctx = goupnp.WithHTTPClient(ctx, c.upnpHTTPClientLocked()) c.mu.Unlock() if ok && oldMapping != nil { - client = oldMapping.client + rootDev = oldMapping.rootDev + loc = oldMapping.loc } else { - ctx := goupnp.WithHTTPClient(ctx, httpClient) - client, err = getUPnPClient(ctx, c.logf, c.debug, gw, meta) + rootDev, loc, err = getUPnPRootDevice(ctx, c.logf, c.debug, gw, meta) if c.debug.VerboseLogs { - c.logf("getUPnPClient: %T, %v", client, err) + c.logf("getUPnPRootDevice: loc=%q err=%v", loc, err) } if err != nil { return netip.AddrPort{}, false } } - if client == nil { + if rootDev == nil { + return netip.AddrPort{}, false + } + + // Now that we have a root device, select the best mapping service from + // it. This makes network requests, and can vary from mapping to + // mapping if the upstream device's connection status changes. + client, err := selectBestService(ctx, c.logf, rootDev, loc) + if err != nil { return netip.AddrPort{}, false } @@ -384,6 +521,8 @@ func (c *Client) getUPnPPortMapping( d := time.Duration(pmpMapLifetimeSec) * time.Second upnp.goodUntil = now.Add(d) upnp.renewAfter = now.Add(d / 2) + upnp.rootDev = rootDev + upnp.loc = loc upnp.client = client c.mu.Lock() defer c.mu.Unlock() @@ -471,7 +610,7 @@ func requestLogger(logf logger.Logf, client *http.Client) *http.Client { resp, err := oldTransport.RoundTrip(req) if err != nil { - logf("response[%d]: err=%v", err) + logf("response[%d]: err=%v", ctr, err) return nil, err } @@ -480,7 +619,7 @@ func requestLogger(logf logger.Logf, client *http.Client) *http.Client { body, err = io.ReadAll(resp.Body) resp.Body.Close() if err != nil { - logf("response[%d]: %d bodyErr=%v", resp.StatusCode, err) + logf("response[%d]: %d bodyErr=%v", ctr, resp.StatusCode, err) return nil, err } resp.Body = io.NopCloser(bytes.NewReader(body)) diff --git a/net/portmapper/upnp_test.go b/net/portmapper/upnp_test.go index d10f6fdb6..ec1446e91 100644 --- a/net/portmapper/upnp_test.go +++ b/net/portmapper/upnp_test.go @@ -218,19 +218,19 @@ func TestGetUPnPClient(t *testing.T) { "google", googleWifiRootDescXML, "*internetgateway2.WANIPConnection2", - "saw UPnP type WANIPConnection2 at http://127.0.0.1:NNN/rootDesc.xml; OnHub (Google)\n", + "saw UPnP type WANIPConnection2 at http://127.0.0.1:NNN/rootDesc.xml; OnHub (Google), method=single\n", }, { "pfsense", pfSenseRootDescXML, "*internetgateway2.WANIPConnection1", - "saw UPnP type WANIPConnection1 at http://127.0.0.1:NNN/rootDesc.xml; FreeBSD router (FreeBSD)\n", + "saw UPnP type WANIPConnection1 at http://127.0.0.1:NNN/rootDesc.xml; FreeBSD router (FreeBSD), method=single\n", }, { "mikrotik", mikrotikRootDescXML, "*internetgateway2.WANIPConnection1", - "saw UPnP type WANIPConnection1 at http://127.0.0.1:NNN/rootDesc.xml; MikroTik Router (MikroTik)\n", + "saw UPnP type WANIPConnection1 at http://127.0.0.1:NNN/rootDesc.xml; MikroTik Router (MikroTik), method=none\n", }, // TODO(bradfitz): find a PPP one in the wild } @@ -246,13 +246,20 @@ func TestGetUPnPClient(t *testing.T) { defer ts.Close() gw, _ := netip.AddrFromSlice(ts.Listener.Addr().(*net.TCPAddr).IP) gw = gw.Unmap() + + ctx := context.Background() + var logBuf tstest.MemLogger - c, err := getUPnPClient(context.Background(), logBuf.Logf, DebugKnobs{}, gw, uPnPDiscoResponse{ + dev, loc, err := getUPnPRootDevice(ctx, logBuf.Logf, DebugKnobs{}, gw, uPnPDiscoResponse{ Location: ts.URL + "/rootDesc.xml", }) if err != nil { t.Fatal(err) } + c, err := selectBestService(ctx, logBuf.Logf, dev, loc) + if err != nil { + t.Fatal(err) + } got := fmt.Sprintf("%T", c) if got != tt.want { t.Errorf("got %v; want %v", got, tt.want) @@ -272,90 +279,53 @@ func TestGetUPnPPortMapping(t *testing.T) { } defer igd.Close() - rootDesc := "" - // This is a very basic fake UPnP server handler. var sawRequestWithLease atomic.Bool - igd.SetUPnPHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - t.Logf("got UPnP request %s %s", r.Method, r.URL.Path) - switch r.URL.Path { - case "/rootDesc.xml": - io.WriteString(w, rootDesc) - case "/ctl/IPConn", "/upnp/control/yomkmsnooi/wanipconn-1": - body, err := io.ReadAll(r.Body) - if err != nil { - t.Errorf("error reading request body: %v", err) - http.Error(w, "bad request", http.StatusBadRequest) - return - } - - // Decode the request type. - var outerRequest struct { - Body struct { - Request struct { - XMLName xml.Name - } `xml:",any"` - Inner string `xml:",innerxml"` - } `xml:"Body"` + handlers := map[string]any{ + "AddPortMapping": func(body []byte) (int, string) { + // Decode a minimal body to determine whether we skip the request or not. + var req struct { + Protocol string `xml:"NewProtocol"` + InternalPort string `xml:"NewInternalPort"` + ExternalPort string `xml:"NewExternalPort"` + InternalClient string `xml:"NewInternalClient"` + LeaseDuration string `xml:"NewLeaseDuration"` } - if err := xml.Unmarshal(body, &outerRequest); err != nil { + if err := xml.Unmarshal(body, &req); err != nil { t.Errorf("bad request: %v", err) - http.Error(w, "bad request", http.StatusBadRequest) - return + return http.StatusBadRequest, "bad request" } - requestType := outerRequest.Body.Request.XMLName.Local - upnpRequest := outerRequest.Body.Inner - t.Logf("UPnP request: %s", requestType) - - switch requestType { - case "AddPortMapping": - // Decode a minimal body to determine whether we skip the request or not. - var req struct { - Protocol string `xml:"NewProtocol"` - InternalPort string `xml:"NewInternalPort"` - ExternalPort string `xml:"NewExternalPort"` - InternalClient string `xml:"NewInternalClient"` - LeaseDuration string `xml:"NewLeaseDuration"` - } - if err := xml.Unmarshal([]byte(upnpRequest), &req); err != nil { - t.Errorf("bad request: %v", err) - http.Error(w, "bad request", http.StatusBadRequest) - return - } - - if req.Protocol != "UDP" { - t.Errorf(`got Protocol=%q, want "UDP"`, req.Protocol) - } - if req.LeaseDuration != "0" { - // Return a fake error to ensure that we fall back to a permanent lease. - io.WriteString(w, testAddPortMappingPermanentLease) - sawRequestWithLease.Store(true) - } else { - // Success! - io.WriteString(w, testAddPortMappingResponse) - } - case "GetExternalIPAddress": - io.WriteString(w, testGetExternalIPAddressResponse) - - case "DeletePortMapping": - // Do nothing for test - - default: - t.Errorf("unhandled UPnP request type %q", requestType) - http.Error(w, "bad request", http.StatusBadRequest) + if req.Protocol != "UDP" { + t.Errorf(`got Protocol=%q, want "UDP"`, req.Protocol) } - default: - t.Logf("ignoring request") - http.NotFound(w, r) - } - })) + if req.LeaseDuration != "0" { + // Return a fake error to ensure that we fall back to a permanent lease. + sawRequestWithLease.Store(true) + return http.StatusOK, testAddPortMappingPermanentLease + } + + // Success! + return http.StatusOK, testAddPortMappingResponse + }, + "GetExternalIPAddress": testGetExternalIPAddressResponse, + "GetStatusInfo": testGetStatusInfoResponse, + "DeletePortMapping": "", // Do nothing for test + } ctx := context.Background() rootDescsToTest := []string{testRootDesc, mikrotikRootDescXML} + for _, rootDesc := range rootDescsToTest { + igd.SetUPnPHandler(&upnpServer{ + t: t, + Desc: rootDesc, + Control: map[string]map[string]any{ + "/ctl/IPConn": handlers, + "/upnp/control/yomkmsnooi/wanipconn-1": handlers, + }, + }) - for _, rootDesc = range rootDescsToTest { c := newTestClient(t, igd) t.Logf("Listening on upnp=%v", c.testUPnPPort) defer c.Close() @@ -391,6 +361,89 @@ func TestGetUPnPPortMapping(t *testing.T) { } } +type upnpServer struct { + t *testing.T + Desc string // root device XML + Control map[string]map[string]any // map["/url"]map["UPnPService"]response +} + +func (u *upnpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + u.t.Logf("got UPnP request %s %s", r.Method, r.URL.Path) + if r.URL.Path == "/rootDesc.xml" { + io.WriteString(w, u.Desc) + return + } + if control, ok := u.Control[r.URL.Path]; ok { + u.handleControl(w, r, control) + return + } + + u.t.Logf("ignoring request") + http.NotFound(w, r) +} + +func (u *upnpServer) handleControl(w http.ResponseWriter, r *http.Request, handlers map[string]any) { + body, err := io.ReadAll(r.Body) + if err != nil { + u.t.Errorf("error reading request body: %v", err) + http.Error(w, "bad request", http.StatusBadRequest) + return + } + + // Decode the request type. + var outerRequest struct { + Body struct { + Request struct { + XMLName xml.Name + } `xml:",any"` + Inner string `xml:",innerxml"` + } `xml:"Body"` + } + if err := xml.Unmarshal(body, &outerRequest); err != nil { + u.t.Errorf("bad request: %v", err) + http.Error(w, "bad request", http.StatusBadRequest) + return + } + + requestType := outerRequest.Body.Request.XMLName.Local + upnpRequest := outerRequest.Body.Inner + u.t.Logf("UPnP request: %s", requestType) + + handler, ok := handlers[requestType] + if !ok { + u.t.Errorf("unhandled UPnP request type %q", requestType) + http.Error(w, "bad request", http.StatusBadRequest) + return + } + + switch v := handler.(type) { + case string: + io.WriteString(w, v) + case []byte: + w.Write(v) + + // Function handlers + case func(string) string: + io.WriteString(w, v(upnpRequest)) + case func([]byte) string: + io.WriteString(w, v([]byte(upnpRequest))) + + case func(string) (int, string): + code, body := v(upnpRequest) + w.WriteHeader(code) + io.WriteString(w, body) + case func([]byte) (int, string): + code, body := v([]byte(upnpRequest)) + w.WriteHeader(code) + io.WriteString(w, body) + + default: + u.t.Fatalf("invalid handler type: %T", v) + http.Error(w, "invalid handler type", http.StatusInternalServerError) + return + } +} + const testRootDesc = ` @@ -486,3 +539,15 @@ const testGetExternalIPAddressResponse = ` ` + +const testGetStatusInfoResponse = ` + + + + Connected + ERROR_NONE + 9999 + + + +`