diff --git a/net/portmapper/igd_test.go b/net/portmapper/igd_test.go index 3268aee6d..7b0ba260f 100644 --- a/net/portmapper/igd_test.go +++ b/net/portmapper/igd_test.go @@ -11,8 +11,10 @@ import ( "net/http" "net/http/httptest" "sync" + "testing" "inet.af/netaddr" + "tailscale.com/types/logger" ) // TestIGD is an IGD (Intenet Gateway Device) for testing. It supports fake @@ -21,15 +23,25 @@ type TestIGD struct { upnpConn net.PacketConn // for UPnP discovery pxpConn net.PacketConn // for NAT-PMP and/or PCP ts *httptest.Server + logf logger.Logf + + // do* will log which packets are sent, but will not reply to unexpected packets. doPMP bool doPCP bool - doUPnP bool // TODO: more options for 3 flavors of UPnP services + doUPnP bool mu sync.Mutex // guards below counters igdCounters } +// TestIGDOptions are options +type TestIGDOptions struct { + PMP bool + PCP bool + UPnP bool // TODO: more options for 3 flavors of UPnP services +} + type igdCounters struct { numUPnPDiscoRecv int32 numUPnPOtherUDPRecv int32 @@ -38,15 +50,21 @@ type igdCounters struct { numPMPDiscoRecv int32 numPCPRecv int32 numPCPDiscoRecv int32 + numPCPMapRecv int32 + numPCPOtherRecv int32 numPMPPublicAddrRecv int32 numPMPBogusRecv int32 + + numFailedWrites int32 + invalidPCPMapPkt int32 } -func NewTestIGD() (*TestIGD, error) { +func NewTestIGD(logf logger.Logf, t TestIGDOptions) (*TestIGD, error) { d := &TestIGD{ - doPMP: true, - doPCP: true, - doUPnP: true, + logf: logf, + doPMP: t.PMP, + doPCP: t.PCP, + doUPnP: t.UPnP, } var err error if d.upnpConn, err = testListenUDP(); err != nil { @@ -74,6 +92,10 @@ func (d *TestIGD) TestUPnPPort() uint16 { return uint16(d.upnpConn.LocalAddr().(*net.UDPAddr).Port) } +func testIPAndGateway() (gw, ip netaddr.IP, ok bool) { + return netaddr.IPv4(127, 0, 0, 1), netaddr.IPv4(1, 2, 3, 4), true +} + func (d *TestIGD) Close() error { d.ts.Close() d.upnpConn.Close() @@ -102,13 +124,19 @@ func (d *TestIGD) serveUPnPDiscovery() { for { n, src, err := d.upnpConn.ReadFrom(buf) if err != nil { + d.logf("serveUPnP failed: %v", err) return } pkt := buf[:n] if bytes.Equal(pkt, uPnPPacket) { // a super lazy "parse" d.inc(&d.counters.numUPnPDiscoRecv) resPkt := []byte(fmt.Sprintf("HTTP/1.1 200 OK\r\nCACHE-CONTROL: max-age=120\r\nST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\nUSN: uuid:bee7052b-49e8-3597-b545-55a1e38ac11::urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\nEXT:\r\nSERVER: Tailscale-Test/1.0 UPnP/1.1 MiniUPnPd/2.2.1\r\nLOCATION: %s\r\nOPT: \"http://schemas.upnp.org/upnp/1/0/\"; ns=01\r\n01-NLS: 1627958564\r\nBOOTID.UPNP.ORG: 1627958564\r\nCONFIGID.UPNP.ORG: 1337\r\n\r\n", d.ts.URL+"/rootDesc.xml")) - d.upnpConn.WriteTo(resPkt, src) + if d.doUPnP { + _, err = d.upnpConn.WriteTo(resPkt, src) + if err != nil { + d.inc(&d.counters.numFailedWrites) + } + } } else { d.inc(&d.counters.numUPnPOtherUDPRecv) } @@ -121,6 +149,7 @@ func (d *TestIGD) servePxP() { for { n, a, err := d.pxpConn.ReadFrom(buf) if err != nil { + d.logf("servePxP failed: %v", err) return } ua := a.(*net.UDPAddr) @@ -164,5 +193,55 @@ func (d *TestIGD) handlePMPQuery(pkt []byte, src netaddr.IPPort) { func (d *TestIGD) handlePCPQuery(pkt []byte, src netaddr.IPPort) { d.inc(&d.counters.numPCPRecv) - // TODO + if len(pkt) < 24 { + return + } + op := pkt[1] + pktSrcBytes := [16]byte{} + copy(pktSrcBytes[:], pkt[8:24]) + pktSrc := netaddr.IPFrom16(pktSrcBytes) + if pktSrc != src.IP() { + // TODO this error isn't fatal but should be rejected by server. + // Since it's a test it's difficult to get them the same though. + d.logf("mismatch of packet source and source IP: got %v, expected %v", pktSrc, src.IP()) + } + switch op { + case pcpOpAnnounce: + d.inc(&d.counters.numPCPDiscoRecv) + if !d.doPCP { + return + } + resp := buildPCPDiscoResponse(pkt) + if _, err := d.pxpConn.WriteTo(resp, src.UDPAddr()); err != nil { + d.inc(&d.counters.numFailedWrites) + } + case pcpOpMap: + if len(pkt) < 60 { + d.logf("got too short packet for pcp op map: %v", pkt) + d.inc(&d.counters.invalidPCPMapPkt) + return + } + d.inc(&d.counters.numPCPMapRecv) + if !d.doPCP { + return + } + resp := buildPCPMapResponse(pkt) + d.pxpConn.WriteTo(resp, src.UDPAddr()) + default: + // unknown op code, ignore it for now. + d.inc(&d.counters.numPCPOtherRecv) + return + } +} + +func newTestClient(t *testing.T, igd *TestIGD) *Client { + var c *Client + c = NewClient(t.Logf, func() { + t.Logf("port map changed") + t.Logf("have mapping: %v", c.HaveMapping()) + }) + c.testPxPPort = igd.TestPxPPort() + c.testUPnPPort = igd.TestUPnPPort() + c.SetGatewayLookupFunc(testIPAndGateway) + return c } diff --git a/net/portmapper/pcp_test.go b/net/portmapper/pcp_test.go index 4bf859d2c..f696795d9 100644 --- a/net/portmapper/pcp_test.go +++ b/net/portmapper/pcp_test.go @@ -5,6 +5,7 @@ package portmapper import ( + "encoding/binary" "testing" "inet.af/netaddr" @@ -25,3 +26,37 @@ func TestParsePCPMapResponse(t *testing.T) { t.Errorf("mismatched external address, got: %v, want: %v", mapping.external, expectedAddr) } } + +const ( + serverResponseBit = 1 << 7 + fakeLifetimeSec = 1<<31 - 1 +) + +func buildPCPDiscoResponse(req []byte) []byte { + out := make([]byte, 24) + out[0] = pcpVersion + out[1] = req[1] | serverResponseBit + out[3] = 0 + // Do not put an epoch time in 8:12, when we start using it, tests that use it should fail. + return out +} + +func buildPCPMapResponse(req []byte) []byte { + out := make([]byte, 24+36) + out[0] = pcpVersion + out[1] = req[1] | serverResponseBit + out[3] = 0 + binary.BigEndian.PutUint32(out[4:8], 1<<30) + // Do not put an epoch time in 8:12, when we start using it, tests that use it should fail. + mapResp := out[24:] + mapReq := req[24:] + // copy nonce, protocol and internal port + copy(mapResp[:13], mapReq[:13]) + copy(mapResp[16:18], mapReq[16:18]) + // assign external port + binary.BigEndian.PutUint16(mapResp[18:20], 4242) + assignedIP := netaddr.IPv4(127, 0, 0, 1) + assignedIP16 := assignedIP.As16() + copy(mapResp[20:36], assignedIP16[:]) + return out +} diff --git a/net/portmapper/portmapper_test.go b/net/portmapper/portmapper_test.go index 4c3026e81..503d03103 100644 --- a/net/portmapper/portmapper_test.go +++ b/net/portmapper/portmapper_test.go @@ -11,9 +11,6 @@ import ( "strconv" "testing" "time" - - "inet.af/netaddr" - "tailscale.com/types/logger" ) func TestCreateOrGetMapping(t *testing.T) { @@ -61,24 +58,15 @@ func TestClientProbeThenMap(t *testing.T) { } func TestProbeIntegration(t *testing.T) { - igd, err := NewTestIGD() + igd, err := NewTestIGD(t.Logf, TestIGDOptions{PMP: true, PCP: true, UPnP: true}) if err != nil { t.Fatal(err) } defer igd.Close() - logf := t.Logf - var c *Client - c = NewClient(logger.WithPrefix(logf, "portmapper: "), func() { - logf("portmapping changed.") - logf("have mapping: %v", c.HaveMapping()) - }) - c.testPxPPort = igd.TestPxPPort() - c.testUPnPPort = igd.TestUPnPPort() + c := newTestClient(t, igd) t.Logf("Listening on pxp=%v, upnp=%v", c.testPxPPort, c.testUPnPPort) - c.SetGatewayLookupFunc(func() (gw, self netaddr.IP, ok bool) { - return netaddr.IPv4(127, 0, 0, 1), netaddr.IPv4(1, 2, 3, 4), true - }) + defer c.Close() res, err := c.Probe(context.Background()) if err != nil { @@ -92,6 +80,7 @@ func TestProbeIntegration(t *testing.T) { numUPnPDiscoRecv: 1, numPMPRecv: 1, numPCPRecv: 1, + numPCPDiscoRecv: 1, numPMPPublicAddrRecv: 1, } if !reflect.DeepEqual(st, want) { @@ -102,3 +91,35 @@ func TestProbeIntegration(t *testing.T) { t.Logf("IGD stats: %+v", st) // TODO(bradfitz): finish } + +func TestPCPIntegration(t *testing.T) { + igd, err := NewTestIGD(t.Logf, TestIGDOptions{PMP: false, PCP: true, UPnP: false}) + if err != nil { + t.Fatal(err) + } + defer igd.Close() + + c := newTestClient(t, igd) + defer c.Close() + res, err := c.Probe(context.Background()) + if err != nil { + t.Fatalf("probe failed: %v", err) + } + if res.UPnP || res.PMP { + t.Errorf("probe unexpectedly saw upnp or pmp: %+v", res) + } + if !res.PCP { + t.Fatalf("probe did not see pcp: %+v", res) + } + + external, err := c.createOrGetMapping(context.Background()) + if err != nil { + t.Fatalf("failed to get mapping: %v", err) + } + if external.IsZero() { + t.Errorf("got zero IP, expected non-zero") + } + if c.mapping == nil { + t.Errorf("got nil mapping after successful createOrGetMapping") + } +}