tstest/natlab: correctly handle dual-stacked PacketConns.

Adds a test with multiple networks, one of which is v4-only.

Signed-off-by: David Anderson <danderson@tailscale.com>
reviewable/pr519/r2
David Anderson 4 years ago committed by Dave Anderson
parent 771eb05bcb
commit 1d4f9852a7

@ -184,29 +184,35 @@ type Machine struct {
interfaces []*Interface interfaces []*Interface
routes []routeEntry // sorted by longest prefix to shortest routes []routeEntry // sorted by longest prefix to shortest
conns map[netaddr.IPPort]*conn conns4 map[netaddr.IPPort]*conn
conns6 map[netaddr.IPPort]*conn
} }
func (m *Machine) deliverIncomingPacket(p []byte, dst, src netaddr.IPPort) { func (m *Machine) deliverIncomingPacket(p []byte, dst, src netaddr.IPPort) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
// TODO(danderson): check behavior of dual stack sockets conns := m.conns4
c, ok := m.conns[dst] if dst.IP.Is6() {
if !ok { conns = m.conns6
dst = netaddr.IPPort{IP: unspecOf(dst.IP), Port: dst.Port}
c, ok = m.conns[dst]
if !ok {
return
} }
possibleDsts := []netaddr.IPPort{
dst,
netaddr.IPPort{IP: v6unspec, Port: dst.Port},
netaddr.IPPort{IP: v4unspec, Port: dst.Port},
}
for _, dst := range possibleDsts {
c, ok := conns[dst]
if !ok {
continue
} }
select { select {
case c.in <- incomingPacket{src: src, p: p}: case c.in <- incomingPacket{src: src, p: p}:
default: default:
// Queue overflow. Just drop it. // Queue overflow. Just drop it.
} }
} }
}
func unspecOf(ip netaddr.IP) netaddr.IP { func unspecOf(ip netaddr.IP) netaddr.IP {
if ip.Is4() { if ip.Is4() {
@ -284,7 +290,12 @@ func (m *Machine) writePacket(p []byte, dst, src netaddr.IPPort) (n int, err err
case src.IP == v4unspec: case src.IP == v4unspec:
src.IP = iface.V4() src.IP = iface.V4()
case src.IP == v6unspec: case src.IP == v6unspec:
// v6unspec in Go means "any src, but match address families"
if dst.IP.Is6() {
src.IP = iface.V6() src.IP = iface.V6()
} else if dst.IP.Is4() {
src.IP = iface.V4()
}
default: default:
if !iface.Contains(src.IP) { if !iface.Contains(src.IP) {
return 0, fmt.Errorf("can't send to %v with src %v on interface %v", dst.IP, src.IP, iface) return 0, fmt.Errorf("can't send to %v with src %v on interface %v", dst.IP, src.IP, iface)
@ -321,56 +332,83 @@ func (m *Machine) hasv6() bool {
return false return false
} }
func (m *Machine) registerConn(c *conn) error { func (m *Machine) registerConn4(c *conn) error {
m.mu.Lock()
defer m.mu.Unlock()
if c.ipp.IP.Is6() && c.ipp.IP != v6unspec {
return fmt.Errorf("registerConn4 got IPv6 %s", c.ipp)
}
if _, ok := m.conns4[c.ipp]; ok {
return fmt.Errorf("duplicate conn listening on %v", c.ipp)
}
if m.conns4 == nil {
m.conns4 = map[netaddr.IPPort]*conn{}
}
m.conns4[c.ipp] = c
return nil
}
func (m *Machine) unregisterConn4(c *conn) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.conns4, c.ipp)
}
func (m *Machine) registerConn6(c *conn) error {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
if _, ok := m.conns[c.ipp]; ok { if c.ipp.IP.Is4() {
return fmt.Errorf("registerConn6 got IPv4 %s", c.ipp)
}
if _, ok := m.conns6[c.ipp]; ok {
return fmt.Errorf("duplicate conn listening on %v", c.ipp) return fmt.Errorf("duplicate conn listening on %v", c.ipp)
} }
if m.conns == nil { if m.conns6 == nil {
m.conns = map[netaddr.IPPort]*conn{} m.conns6 = map[netaddr.IPPort]*conn{}
} }
m.conns[c.ipp] = c m.conns6[c.ipp] = c
return nil return nil
} }
func (m *Machine) unregisterConn(c *conn) { func (m *Machine) unregisterConn6(c *conn) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
delete(m.conns, c.ipp) delete(m.conns6, c.ipp)
} }
func (m *Machine) AddNetwork(n *Network) {} func (m *Machine) AddNetwork(n *Network) {}
func (m *Machine) ListenPacket(network, address string) (net.PacketConn, error) { func (m *Machine) ListenPacket(network, address string) (net.PacketConn, error) {
// if udp4, udp6, etc... look at address IP vs unspec // if udp4, udp6, etc... look at address IP vs unspec
var fam uint8 var (
fam uint8
ip netaddr.IP
)
switch network { switch network {
default: default:
return nil, fmt.Errorf("unsupported network type %q", network) return nil, fmt.Errorf("unsupported network type %q", network)
case "udp": case "udp":
fam = 0
ip = v6unspec
case "udp4": case "udp4":
fam = 4 fam = 4
ip = v4unspec
case "udp6": case "udp6":
fam = 6 fam = 6
ip = v6unspec
} }
host, portStr, err := net.SplitHostPort(address) host, portStr, err := net.SplitHostPort(address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if host == "" { if host != "" {
if m.hasv6() { ip, err = netaddr.ParseIP(host)
host = "::"
} else {
host = "0.0.0.0"
}
}
port, err := strconv.ParseUint(portStr, 10, 16)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ip, err := netaddr.ParseIP(host) }
port, err := strconv.ParseUint(portStr, 10, 16)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -382,9 +420,23 @@ func (m *Machine) ListenPacket(network, address string) (net.PacketConn, error)
ipp: ipp, ipp: ipp,
in: make(chan incomingPacket, 100), // arbitrary in: make(chan incomingPacket, 100), // arbitrary
} }
if err := m.registerConn(c); err != nil { switch c.fam {
case 0:
if err := m.registerConn4(c); err != nil {
return nil, err
}
if err := m.registerConn6(c); err != nil {
return nil, err
}
case 4:
if err := m.registerConn4(c); err != nil {
return nil, err return nil, err
} }
case 6:
if err := m.registerConn6(c); err != nil {
return nil, err
}
}
return c, nil return c, nil
} }
@ -437,7 +489,15 @@ func (c *conn) Close() error {
return nil return nil
} }
c.closed = true c.closed = true
c.m.unregisterConn(c) switch c.fam {
case 0:
c.m.unregisterConn4(c)
c.m.unregisterConn6(c)
case 4:
c.m.unregisterConn4(c)
case 6:
c.m.unregisterConn6(c)
}
c.breakActiveReadsLocked() c.breakActiveReadsLocked()
return nil return nil
} }

@ -77,45 +77,68 @@ func TestSendPacket(t *testing.T) {
} }
} }
func TestLAN(t *testing.T) { func TestMultiNetwork(t *testing.T) {
// TODO: very duplicate-ey with the previous test, but important
// right now to test explicit construction of Networks.
lan := Network{ lan := Network{
Name: "lan1", Name: "lan",
Prefix4: mustPrefix("192.168.0.0/24"), Prefix4: mustPrefix("192.168.0.0/24"),
} }
internet := NewInternet()
foo := NewMachine("foo") client := NewMachine("client")
bar := NewMachine("bar") nat := NewMachine("nat")
ifFoo := foo.Attach("eth0", &lan) server := NewMachine("server")
ifBar := bar.Attach("eth0", &lan)
ifClient := client.Attach("eth0", &lan)
ifNATWAN := nat.Attach("ethwan", internet)
ifNATLAN := nat.Attach("ethlan", &lan)
ifServer := server.Attach("eth0", internet)
fooPC, err := foo.ListenPacket("udp4", ":123") clientPC, err := client.ListenPacket("udp", ":123")
if err != nil {
t.Fatal(err)
}
natPC, err := nat.ListenPacket("udp", ":456")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
barPC, err := bar.ListenPacket("udp4", ":456") serverPC, err := server.ListenPacket("udp", ":789")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
const msg = "message" clientAddr := netaddr.IPPort{IP: ifClient.V4(), Port: 123}
barAddr := netaddr.IPPort{IP: ifBar.V4(), Port: 456} natLANAddr := netaddr.IPPort{IP: ifNATLAN.V4(), Port: 456}
if _, err := fooPC.WriteTo([]byte(msg), barAddr.UDPAddr()); err != nil { natWANAddr := netaddr.IPPort{IP: ifNATWAN.V4(), Port: 456}
serverAddr := netaddr.IPPort{IP: ifServer.V4(), Port: 789}
const msg1, msg2 = "hello", "world"
if _, err := natPC.WriteTo([]byte(msg1), clientAddr.UDPAddr()); err != nil {
t.Fatal(err)
}
if _, err := natPC.WriteTo([]byte(msg2), serverAddr.UDPAddr()); err != nil {
t.Fatal(err) t.Fatal(err)
} }
buf := make([]byte, 1500) buf := make([]byte, 1500)
n, addr, err := barPC.ReadFrom(buf) n, addr, err := clientPC.ReadFrom(buf)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
buf = buf[:n] if string(buf[:n]) != msg1 {
if string(buf) != msg { t.Errorf("read %q; want %q", buf[:n], msg1)
t.Errorf("read %q; want %q", buf, msg)
} }
fooAddr := netaddr.IPPort{IP: ifFoo.V4(), Port: 123} if addr.String() != natLANAddr.String() {
if addr.String() != fooAddr.String() { t.Errorf("addr = %q; want %q", addr, natLANAddr)
t.Errorf("addr = %q; want %q", addr, fooAddr) }
n, addr, err = serverPC.ReadFrom(buf)
if err != nil {
t.Fatal(err)
}
if string(buf[:n]) != msg2 {
t.Errorf("read %q; want %q", buf[:n], msg2)
}
if addr.String() != natWANAddr.String() {
t.Errorf("addr = %q; want %q", addr, natLANAddr)
} }
} }

Loading…
Cancel
Save