diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index 6f1c91a60..e1638b3f3 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -727,10 +727,9 @@ func (c *Direct) sendMapRequest(ctx context.Context, maxPolls int, cb func(*netm } }() - sess := newMapSession() + sess := newMapSession(persist.PrivateNodeKey) sess.logf = c.logf sess.vlogf = vlogf - sess.persist = persist sess.machinePubKey = machinePubKey sess.keepSharerAndUserSplit = c.keepSharerAndUserSplit diff --git a/control/controlclient/map.go b/control/controlclient/map.go index 5f3d7a463..06205c640 100644 --- a/control/controlclient/map.go +++ b/control/controlclient/map.go @@ -11,7 +11,7 @@ import ( "tailscale.com/tailcfg" "tailscale.com/types/logger" "tailscale.com/types/netmap" - "tailscale.com/types/persist" + "tailscale.com/types/wgkey" "tailscale.com/wgengine/filter" ) @@ -25,9 +25,9 @@ import ( // one MapRequest). type mapSession struct { // Immutable fields. + privateNodeKey wgkey.Private logf logger.Logf vlogf logger.Logf - persist persist.Persist machinePubKey tailcfg.MachineKey keepSharerAndUserSplit bool // see Options.KeepSharerAndUserSplit @@ -44,8 +44,9 @@ type mapSession struct { netMapBuilding *netmap.NetworkMap } -func newMapSession() *mapSession { +func newMapSession(privateNodeKey wgkey.Private) *mapSession { ms := &mapSession{ + privateNodeKey: privateNodeKey, logf: logger.Discard, vlogf: logger.Discard, lastDNSConfig: new(tailcfg.DNSConfig), @@ -98,8 +99,8 @@ func (ms *mapSession) netmapForResponse(resp *tailcfg.MapResponse) *netmap.Netwo nm := &netmap.NetworkMap{ SelfNode: resp.Node, - NodeKey: tailcfg.NodeKey(ms.persist.PrivateNodeKey.Public()), - PrivateKey: ms.persist.PrivateNodeKey, + NodeKey: tailcfg.NodeKey(ms.privateNodeKey.Public()), + PrivateKey: ms.privateNodeKey, MachineKey: ms.machinePubKey, Expiry: resp.Node.KeyExpiry, Name: resp.Node.Name, diff --git a/control/controlclient/map_test.go b/control/controlclient/map_test.go index 137dd2863..a0abc8989 100644 --- a/control/controlclient/map_test.go +++ b/control/controlclient/map_test.go @@ -12,6 +12,8 @@ import ( "time" "tailscale.com/tailcfg" + "tailscale.com/types/netmap" + "tailscale.com/types/wgkey" ) func TestUndeltaPeers(t *testing.T) { @@ -165,3 +167,93 @@ func formatNodes(nodes []*tailcfg.Node) string { } return sb.String() } + +func newTestMapSession(t *testing.T) *mapSession { + k, err := wgkey.NewPrivate() + if err != nil { + t.Fatal(err) + } + return newMapSession(k) +} + +func TestNetmapForResponse(t *testing.T) { + t.Run("implicit_packetfilter", func(t *testing.T) { + somePacketFilter := []tailcfg.FilterRule{ + { + SrcIPs: []string{"*"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.2.3.4", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + }, + }, + } + ms := newTestMapSession(t) + nm1 := ms.netmapForResponse(&tailcfg.MapResponse{ + Node: new(tailcfg.Node), + PacketFilter: somePacketFilter, + }) + if len(nm1.PacketFilter) == 0 { + t.Fatalf("zero length PacketFilter") + } + nm2 := ms.netmapForResponse(&tailcfg.MapResponse{ + Node: new(tailcfg.Node), + PacketFilter: nil, // testing that the server can omit this. + }) + if len(nm1.PacketFilter) == 0 { + t.Fatalf("zero length PacketFilter in 2nd netmap") + } + if !reflect.DeepEqual(nm1.PacketFilter, nm2.PacketFilter) { + t.Error("packet filters differ") + } + }) + t.Run("implicit_dnsconfig", func(t *testing.T) { + someDNSConfig := &tailcfg.DNSConfig{Domains: []string{"foo", "bar"}} + ms := newTestMapSession(t) + nm1 := ms.netmapForResponse(&tailcfg.MapResponse{ + Node: new(tailcfg.Node), + DNSConfig: someDNSConfig, + }) + if !reflect.DeepEqual(nm1.DNS, *someDNSConfig) { + t.Fatalf("1st DNS wrong") + } + nm2 := ms.netmapForResponse(&tailcfg.MapResponse{ + Node: new(tailcfg.Node), + DNSConfig: nil, // implict + }) + if !reflect.DeepEqual(nm2.DNS, *someDNSConfig) { + t.Fatalf("2nd DNS wrong") + } + }) + t.Run("collect_services", func(t *testing.T) { + ms := newTestMapSession(t) + var nm *netmap.NetworkMap + wantCollect := func(v bool) { + t.Helper() + if nm.CollectServices != v { + t.Errorf("netmap.CollectServices = %v; want %v", nm.CollectServices, v) + } + } + + nm = ms.netmapForResponse(&tailcfg.MapResponse{ + Node: new(tailcfg.Node), + }) + wantCollect(false) + + nm = ms.netmapForResponse(&tailcfg.MapResponse{ + Node: new(tailcfg.Node), + CollectServices: "false", + }) + wantCollect(false) + + nm = ms.netmapForResponse(&tailcfg.MapResponse{ + Node: new(tailcfg.Node), + CollectServices: "true", + }) + wantCollect(true) + + nm = ms.netmapForResponse(&tailcfg.MapResponse{ + Node: new(tailcfg.Node), + CollectServices: "", + }) + wantCollect(true) + }) +}