diff --git a/tstest/natlab/vnet/vnet.go b/tstest/natlab/vnet/vnet.go index 54f78ab3e..725b71789 100644 --- a/tstest/natlab/vnet/vnet.go +++ b/tstest/natlab/vnet/vnet.go @@ -1286,6 +1286,10 @@ func (n *network) handleUDPPacketForRouter(ep EthernetPacket, udp *layers.UDP, t srcIP, dstIP := flow.src, flow.dst if isDHCPRequest(packet) { + if !n.v4 { + n.logf("dropping DHCPv4 packet on v6-only network") + return + } res, err := n.s.createDHCPResponse(packet) if err != nil { n.logf("createDHCPResponse: %v", err) @@ -1587,6 +1591,7 @@ func (s *Server) createDHCPResponse(request gopacket.Packet) ([]byte, error) { return mkPacketErr(eth, ip, udp, response) } +// isDHCPRequest reports whether pkt is a DHCPv4 request. func isDHCPRequest(pkt gopacket.Packet) bool { v4, ok := pkt.Layer(layers.LayerTypeIPv4).(*layers.IPv4) if !ok || v4.Protocol != layers.IPProtocolUDP { diff --git a/tstest/natlab/vnet/vnet_test.go b/tstest/natlab/vnet/vnet_test.go index 470ec21df..90657b885 100644 --- a/tstest/natlab/vnet/vnet_test.go +++ b/tstest/natlab/vnet/vnet_test.go @@ -98,7 +98,25 @@ func TestPacketSideEffects(t *testing.T) { logSubstr("dropping IPv6 packet on v4-only network"), ), }, - // TODO(bradfitz): DHCP request + response + { + name: "dhcp-discover", + pkt: mkDHCP(nodeMac(1), layers.DHCPMsgTypeDiscover), + check: all( + numPkts(2), // DHCP discover broadcast to node2 also, and the DHCP reply from router + pktSubstr("SrcMAC=52:cc:cc:cc:cc:01 DstMAC=ff:ff:ff:ff:ff:ff"), + pktSubstr("Options=[Option(ServerID:192.168.0.1), Option(MessageType:Offer)]}"), + ), + }, + { + name: "dhcp-request", + pkt: mkDHCP(nodeMac(1), layers.DHCPMsgTypeRequest), + check: all( + numPkts(2), // DHCP discover broadcast to node2 also, and the DHCP reply from router + pktSubstr("SrcMAC=52:cc:cc:cc:cc:01 DstMAC=ff:ff:ff:ff:ff:ff"), + pktSubstr("YourClientIP=192.168.0.101"), + pktSubstr("Options=[Option(ServerID:192.168.0.1), Option(MessageType:Ack), Option(LeaseTime:3600), Option(Router:[192 168 0 1]), Option(DNS:[4 11 4 11]), Option(SubnetMask:255.255.255.0)]}"), + ), + }, }, }, { @@ -132,6 +150,24 @@ func TestPacketSideEffects(t *testing.T) { pktSubstr("TypeCode=EchoRequest"), ), }, + { + name: "no-dhcp-on-v6-disco", + pkt: mkDHCP(nodeMac(1), layers.DHCPMsgTypeDiscover), + check: all( + numPkts(1), // DHCP discover broadcast to node2 only + logSubstr("dropping DHCPv4 packet on v6-only network"), + pktSubstr("SrcMAC=52:cc:cc:cc:cc:01 DstMAC=ff:ff:ff:ff:ff:ff"), + ), + }, + { + name: "no-dhcp-on-v6-request", + pkt: mkDHCP(nodeMac(1), layers.DHCPMsgTypeRequest), + check: all( + numPkts(1), // DHCP request broadcast to node2 only + pktSubstr("SrcMAC=52:cc:cc:cc:cc:01 DstMAC=ff:ff:ff:ff:ff:ff"), + logSubstr("dropping DHCPv4 packet on v6-only network"), + ), + }, }, }, } @@ -145,20 +181,22 @@ func TestPacketSideEffects(t *testing.T) { for _, tt := range tt.tests { t.Run(tt.name, func(t *testing.T) { - se := &sideEffects{} - s.SetLoggerForTest(se.logf) - for mac := range s.MACs() { - s.RegisterSinkForTest(mac, func(eth []byte) { - se.got = append(se.got, eth) - }) - } + se := newSideEffects(s) if err := s.handleEthernetFrameFromVM(tt.pkt); err != nil { t.Fatal(err) } if tt.check != nil { if err := tt.check(se); err != nil { - t.Fatal(err) + t.Error(err) + } + } + if t.Failed() { + t.Logf("logs were:\n%s", strings.Join(se.logs, "\n")) + for i, rp := range se.got { + p := gopacket.NewPacket(rp.eth, layers.LayerTypeEthernet, gopacket.Lazy) + got := p.String() + t.Logf("[pkt%d, port %v]:\n%s\n", i, rp.port, got) } } }) @@ -285,11 +323,63 @@ func mkDNSReq(ipVer int) []byte { return mkPacket(eth, ip, udp, dns) } +func mkDHCP(srcMAC MAC, typ layers.DHCPMsgType) []byte { + eth := &layers.Ethernet{ + SrcMAC: srcMAC.HWAddr(), + DstMAC: macBroadcast.HWAddr(), + EthernetType: layers.EthernetTypeIPv4, + } + ip := &layers.IPv4{ + Version: 4, + Protocol: layers.IPProtocolUDP, + SrcIP: net.ParseIP("0.0.0.0"), + DstIP: net.ParseIP("255.255.255.255"), + } + udp := &layers.UDP{ + SrcPort: 68, + DstPort: 67, + } + dhcp := &layers.DHCPv4{ + Operation: layers.DHCPOpRequest, + HardwareType: layers.LinkTypeEthernet, + HardwareLen: 6, + Xid: 0, + Secs: 0, + Flags: 0, + ClientHWAddr: srcMAC[:], + Options: []layers.DHCPOption{ + {Type: layers.DHCPOptMessageType, Length: 1, Data: []byte{byte(typ)}}, + }, + } + return mkPacket(eth, ip, udp, dhcp) +} + +// receivedPacket is an ethernet frame that was received during a test. +type receivedPacket struct { + port MAC // MAC address of client that received the packet + eth []byte // ethernet frame; dst MAC might be ff:ff:ff:ff:ff:ff, etc +} + // sideEffects gathers side effects as a result of sending a packet and tests // whether those effects were as desired. type sideEffects struct { logs []string - got [][]byte // ethernet packets received + got []receivedPacket // ethernet packets received +} + +// newSideEffects creates a new sideEffects recorder, registering itself with s. +func newSideEffects(s *Server) *sideEffects { + se := &sideEffects{} + s.SetLoggerForTest(se.logf) + for mac := range s.MACs() { + s.RegisterSinkForTest(mac, func(eth []byte) { + se.got = append(se.got, receivedPacket{ + port: mac, + eth: eth, + }) + }) + } + return se } func (se *sideEffects) logf(format string, args ...any) { @@ -318,7 +408,7 @@ func logSubstr(sub string) func(*sideEffects) error { return nil } } - return fmt.Errorf("expected log substring %q not found; log statements were:\n%s", sub, strings.Join(se.logs, "\n")) + return fmt.Errorf("expected log substring %q not found", sub) } } @@ -327,16 +417,14 @@ func logSubstr(sub string) func(*sideEffects) error { // substring sub. func pktSubstr(sub string) func(*sideEffects) error { return func(se *sideEffects) error { - var pkts bytes.Buffer - for i, pkt := range se.got { - pkt := gopacket.NewPacket(pkt, layers.LayerTypeEthernet, gopacket.Lazy) + for _, pkt := range se.got { + pkt := gopacket.NewPacket(pkt.eth, layers.LayerTypeEthernet, gopacket.Lazy) got := pkt.String() - fmt.Fprintf(&pkts, "[pkt%d]:\n%s\n", i, got) if strings.Contains(got, sub) { return nil } } - return fmt.Errorf("packet summary with substring %q not found; packets were:\n%s", sub, pkts.Bytes()) + return fmt.Errorf("packet summary with substring %q not found", sub) } } @@ -347,13 +435,7 @@ func numPkts(want int) func(*sideEffects) error { if len(se.got) == want { return nil } - var pkts bytes.Buffer - for i, pkt := range se.got { - pkt := gopacket.NewPacket(pkt, layers.LayerTypeEthernet, gopacket.Lazy) - got := pkt.String() - fmt.Fprintf(&pkts, "[pkt%d]:\n%s\n", i, got) - } - return fmt.Errorf("got %d packets, want %d. packets were:\n%s", len(se.got), want, pkts.Bytes()) + return fmt.Errorf("got %d packets, want %d", len(se.got), want) } }