From 1d4f9852a7bbb9c5525b71279edf6c1a9e39e81d Mon Sep 17 00:00:00 2001 From: David Anderson Date: Fri, 3 Jul 2020 01:47:06 +0000 Subject: [PATCH] tstest/natlab: correctly handle dual-stacked PacketConns. Adds a test with multiple networks, one of which is v4-only. Signed-off-by: David Anderson --- tstest/natlab/natlab.go | 128 +++++++++++++++++++++++++---------- tstest/natlab/natlab_test.go | 63 +++++++++++------ 2 files changed, 137 insertions(+), 54 deletions(-) diff --git a/tstest/natlab/natlab.go b/tstest/natlab/natlab.go index a377c3bae..1505a37a5 100644 --- a/tstest/natlab/natlab.go +++ b/tstest/natlab/natlab.go @@ -184,27 +184,33 @@ type Machine struct { interfaces []*Interface routes []routeEntry // sorted by longest prefix to shortest - conns map[netaddr.IPPort]*conn + conns4 map[netaddr.IPPort]*conn + conns6 map[netaddr.IPPort]*conn } func (m *Machine) deliverIncomingPacket(p []byte, dst, src netaddr.IPPort) { m.mu.Lock() defer m.mu.Unlock() - // TODO(danderson): check behavior of dual stack sockets - c, ok := m.conns[dst] - if !ok { - dst = netaddr.IPPort{IP: unspecOf(dst.IP), Port: dst.Port} - c, ok = m.conns[dst] + conns := m.conns4 + if dst.IP.Is6() { + conns = m.conns6 + } + possibleDsts := []netaddr.IPPort{ + dst, + netaddr.IPPort{IP: v6unspec, Port: dst.Port}, + netaddr.IPPort{IP: v4unspec, Port: dst.Port}, + } + for _, dst := range possibleDsts { + c, ok := conns[dst] if !ok { - return + continue + } + select { + case c.in <- incomingPacket{src: src, p: p}: + default: + // Queue overflow. Just drop it. } - } - - select { - case c.in <- incomingPacket{src: src, p: p}: - default: - // Queue overflow. Just drop it. } } @@ -284,7 +290,12 @@ func (m *Machine) writePacket(p []byte, dst, src netaddr.IPPort) (n int, err err case src.IP == v4unspec: src.IP = iface.V4() case src.IP == v6unspec: - src.IP = iface.V6() + // v6unspec in Go means "any src, but match address families" + if dst.IP.Is6() { + src.IP = iface.V6() + } else if dst.IP.Is4() { + src.IP = iface.V4() + } default: if !iface.Contains(src.IP) { return 0, fmt.Errorf("can't send to %v with src %v on interface %v", dst.IP, src.IP, iface) @@ -321,59 +332,86 @@ func (m *Machine) hasv6() bool { return false } -func (m *Machine) registerConn(c *conn) error { +func (m *Machine) registerConn4(c *conn) error { + m.mu.Lock() + defer m.mu.Unlock() + if c.ipp.IP.Is6() && c.ipp.IP != v6unspec { + return fmt.Errorf("registerConn4 got IPv6 %s", c.ipp) + } + if _, ok := m.conns4[c.ipp]; ok { + return fmt.Errorf("duplicate conn listening on %v", c.ipp) + } + if m.conns4 == nil { + m.conns4 = map[netaddr.IPPort]*conn{} + } + m.conns4[c.ipp] = c + return nil +} + +func (m *Machine) unregisterConn4(c *conn) { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.conns4, c.ipp) +} + +func (m *Machine) registerConn6(c *conn) error { m.mu.Lock() defer m.mu.Unlock() - if _, ok := m.conns[c.ipp]; ok { + if c.ipp.IP.Is4() { + return fmt.Errorf("registerConn6 got IPv4 %s", c.ipp) + } + if _, ok := m.conns6[c.ipp]; ok { return fmt.Errorf("duplicate conn listening on %v", c.ipp) } - if m.conns == nil { - m.conns = map[netaddr.IPPort]*conn{} + if m.conns6 == nil { + m.conns6 = map[netaddr.IPPort]*conn{} } - m.conns[c.ipp] = c + m.conns6[c.ipp] = c return nil } -func (m *Machine) unregisterConn(c *conn) { +func (m *Machine) unregisterConn6(c *conn) { m.mu.Lock() defer m.mu.Unlock() - delete(m.conns, c.ipp) + delete(m.conns6, c.ipp) } func (m *Machine) AddNetwork(n *Network) {} func (m *Machine) ListenPacket(network, address string) (net.PacketConn, error) { // if udp4, udp6, etc... look at address IP vs unspec - var fam uint8 + var ( + fam uint8 + ip netaddr.IP + ) switch network { default: return nil, fmt.Errorf("unsupported network type %q", network) case "udp": + fam = 0 + ip = v6unspec case "udp4": fam = 4 + ip = v4unspec case "udp6": fam = 6 + ip = v6unspec } host, portStr, err := net.SplitHostPort(address) if err != nil { return nil, err } - if host == "" { - if m.hasv6() { - host = "::" - } else { - host = "0.0.0.0" + if host != "" { + ip, err = netaddr.ParseIP(host) + if err != nil { + return nil, err } } port, err := strconv.ParseUint(portStr, 10, 16) if err != nil { return nil, err } - ip, err := netaddr.ParseIP(host) - if err != nil { - return nil, err - } ipp := netaddr.IPPort{IP: ip, Port: uint16(port)} c := &conn{ @@ -382,8 +420,22 @@ func (m *Machine) ListenPacket(network, address string) (net.PacketConn, error) ipp: ipp, in: make(chan incomingPacket, 100), // arbitrary } - if err := m.registerConn(c); err != nil { - return nil, err + switch c.fam { + case 0: + if err := m.registerConn4(c); err != nil { + return nil, err + } + if err := m.registerConn6(c); err != nil { + return nil, err + } + case 4: + if err := m.registerConn4(c); err != nil { + return nil, err + } + case 6: + if err := m.registerConn6(c); err != nil { + return nil, err + } } return c, nil } @@ -437,7 +489,15 @@ func (c *conn) Close() error { return nil } c.closed = true - c.m.unregisterConn(c) + switch c.fam { + case 0: + c.m.unregisterConn4(c) + c.m.unregisterConn6(c) + case 4: + c.m.unregisterConn4(c) + case 6: + c.m.unregisterConn6(c) + } c.breakActiveReadsLocked() return nil } diff --git a/tstest/natlab/natlab_test.go b/tstest/natlab/natlab_test.go index d06653b3e..131851282 100644 --- a/tstest/natlab/natlab_test.go +++ b/tstest/natlab/natlab_test.go @@ -77,45 +77,68 @@ func TestSendPacket(t *testing.T) { } } -func TestLAN(t *testing.T) { - // TODO: very duplicate-ey with the previous test, but important - // right now to test explicit construction of Networks. +func TestMultiNetwork(t *testing.T) { lan := Network{ - Name: "lan1", + Name: "lan", Prefix4: mustPrefix("192.168.0.0/24"), } + internet := NewInternet() - foo := NewMachine("foo") - bar := NewMachine("bar") - ifFoo := foo.Attach("eth0", &lan) - ifBar := bar.Attach("eth0", &lan) + client := NewMachine("client") + nat := NewMachine("nat") + server := NewMachine("server") + + ifClient := client.Attach("eth0", &lan) + ifNATWAN := nat.Attach("ethwan", internet) + ifNATLAN := nat.Attach("ethlan", &lan) + ifServer := server.Attach("eth0", internet) - fooPC, err := foo.ListenPacket("udp4", ":123") + clientPC, err := client.ListenPacket("udp", ":123") + if err != nil { + t.Fatal(err) + } + natPC, err := nat.ListenPacket("udp", ":456") if err != nil { t.Fatal(err) } - barPC, err := bar.ListenPacket("udp4", ":456") + serverPC, err := server.ListenPacket("udp", ":789") if err != nil { t.Fatal(err) } - const msg = "message" - barAddr := netaddr.IPPort{IP: ifBar.V4(), Port: 456} - if _, err := fooPC.WriteTo([]byte(msg), barAddr.UDPAddr()); err != nil { + clientAddr := netaddr.IPPort{IP: ifClient.V4(), Port: 123} + natLANAddr := netaddr.IPPort{IP: ifNATLAN.V4(), Port: 456} + natWANAddr := netaddr.IPPort{IP: ifNATWAN.V4(), Port: 456} + serverAddr := netaddr.IPPort{IP: ifServer.V4(), Port: 789} + + const msg1, msg2 = "hello", "world" + if _, err := natPC.WriteTo([]byte(msg1), clientAddr.UDPAddr()); err != nil { + t.Fatal(err) + } + if _, err := natPC.WriteTo([]byte(msg2), serverAddr.UDPAddr()); err != nil { t.Fatal(err) } buf := make([]byte, 1500) - n, addr, err := barPC.ReadFrom(buf) + n, addr, err := clientPC.ReadFrom(buf) if err != nil { t.Fatal(err) } - buf = buf[:n] - if string(buf) != msg { - t.Errorf("read %q; want %q", buf, msg) + if string(buf[:n]) != msg1 { + t.Errorf("read %q; want %q", buf[:n], msg1) } - fooAddr := netaddr.IPPort{IP: ifFoo.V4(), Port: 123} - if addr.String() != fooAddr.String() { - t.Errorf("addr = %q; want %q", addr, fooAddr) + if addr.String() != natLANAddr.String() { + t.Errorf("addr = %q; want %q", addr, natLANAddr) + } + + n, addr, err = serverPC.ReadFrom(buf) + if err != nil { + t.Fatal(err) + } + if string(buf[:n]) != msg2 { + t.Errorf("read %q; want %q", buf[:n], msg2) + } + if addr.String() != natWANAddr.String() { + t.Errorf("addr = %q; want %q", addr, natLANAddr) } }