// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package wgcfg import ( "bufio" "bytes" "io" "net" "os" "sort" "strings" "sync" "testing" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" "inet.af/netaddr" "tailscale.com/tailcfg" "tailscale.com/types/wgkey" ) func TestDeviceConfig(t *testing.T) { newPrivateKey := func() (wgkey.Key, wgkey.Private) { t.Helper() pk, err := wgkey.NewPrivate() if err != nil { t.Fatal(err) } return wgkey.Key(pk.Public()), wgkey.Private(pk) } k1, pk1 := newPrivateKey() ip1 := netaddr.MustParseIPPrefix("10.0.0.1/32") k2, pk2 := newPrivateKey() ip2 := netaddr.MustParseIPPrefix("10.0.0.2/32") k3, _ := newPrivateKey() ip3 := netaddr.MustParseIPPrefix("10.0.0.3/32") cfg1 := &Config{ PrivateKey: wgkey.Private(pk1), Peers: []Peer{{ PublicKey: k2, AllowedIPs: []netaddr.IPPrefix{ip2}, }}, } cfg2 := &Config{ PrivateKey: wgkey.Private(pk2), Peers: []Peer{{ PublicKey: k1, AllowedIPs: []netaddr.IPPrefix{ip1}, PersistentKeepalive: 5, }}, } device1 := device.NewDevice(newNilTun(), new(noopBind), device.NewLogger(device.LogLevelError, "device1")) device2 := device.NewDevice(newNilTun(), new(noopBind), device.NewLogger(device.LogLevelError, "device2")) defer device1.Close() defer device2.Close() cmp := func(t *testing.T, d *device.Device, want *Config) { t.Helper() got, err := DeviceConfig(d) if err != nil { t.Fatal(err) } prev := new(Config) gotbuf := new(strings.Builder) err = got.ToUAPI(gotbuf, prev) gotStr := gotbuf.String() if err != nil { t.Errorf("got.ToUAPI(): error: %v", err) return } wantbuf := new(strings.Builder) err = want.ToUAPI(wantbuf, prev) wantStr := wantbuf.String() if err != nil { t.Errorf("want.ToUAPI(): error: %v", err) return } if gotStr != wantStr { buf := new(bytes.Buffer) w := bufio.NewWriter(buf) if err := d.IpcGetOperation(w); err != nil { t.Errorf("on 
error, could not IpcGetOperation: %v", err) } w.Flush() t.Errorf("config mismatch:\n---- got:\n%s\n---- want:\n%s\n---- uapi:\n%s", gotStr, wantStr, buf.String()) } } t.Run("device1 config", func(t *testing.T) { if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { t.Fatal(err) } cmp(t, device1, cfg1) }) t.Run("device2 config", func(t *testing.T) { if err := ReconfigDevice(device2, cfg2, t.Logf); err != nil { t.Fatal(err) } cmp(t, device2, cfg2) }) // This is only to test that Config and Reconfig are properly synchronized. t.Run("device2 config/reconfig", func(t *testing.T) { var wg sync.WaitGroup wg.Add(2) go func() { ReconfigDevice(device2, cfg2, t.Logf) wg.Done() }() go func() { DeviceConfig(device2) wg.Done() }() wg.Wait() }) t.Run("device1 modify peer", func(t *testing.T) { cfg1.Peers[0].DiscoKey = tailcfg.DiscoKey{1} if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { t.Fatal(err) } cmp(t, device1, cfg1) }) t.Run("device1 replace endpoint", func(t *testing.T) { cfg1.Peers[0].DiscoKey = tailcfg.DiscoKey{2} if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { t.Fatal(err) } cmp(t, device1, cfg1) }) t.Run("device1 add new peer", func(t *testing.T) { cfg1.Peers = append(cfg1.Peers, Peer{ PublicKey: k3, AllowedIPs: []netaddr.IPPrefix{ip3}, }) sort.Slice(cfg1.Peers, func(i, j int) bool { return cfg1.Peers[i].PublicKey.LessThan(&cfg1.Peers[j].PublicKey) }) origCfg, err := DeviceConfig(device1) if err != nil { t.Fatal(err) } if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { t.Fatal(err) } cmp(t, device1, cfg1) newCfg, err := DeviceConfig(device1) if err != nil { t.Fatal(err) } peer0 := func(cfg *Config) Peer { p, ok := cfg.PeerWithKey(k2) if !ok { t.Helper() t.Fatal("failed to look up peer 2") } return p } peersEqual := func(p, q Peer) bool { return p.PublicKey == q.PublicKey && p.DiscoKey == q.DiscoKey && p.PersistentKeepalive == q.PersistentKeepalive && cidrsEqual(p.AllowedIPs, q.AllowedIPs) } if !peersEqual(peer0(origCfg), 
peer0(newCfg)) { t.Error("reconfig modified old peer") } }) t.Run("device1 remove peer", func(t *testing.T) { removeKey := cfg1.Peers[len(cfg1.Peers)-1].PublicKey cfg1.Peers = cfg1.Peers[:len(cfg1.Peers)-1] if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { t.Fatal(err) } cmp(t, device1, cfg1) newCfg, err := DeviceConfig(device1) if err != nil { t.Fatal(err) } _, ok := newCfg.PeerWithKey(removeKey) if ok { t.Error("reconfig failed to remove peer") } }) } // TODO: replace with a loopback tunnel type nilTun struct { events chan tun.Event closed chan struct{} } func newNilTun() tun.Device { return &nilTun{ events: make(chan tun.Event), closed: make(chan struct{}), } } func (t *nilTun) File() *os.File { return nil } func (t *nilTun) Flush() error { return nil } func (t *nilTun) MTU() (int, error) { return 1420, nil } func (t *nilTun) Name() (string, error) { return "niltun", nil } func (t *nilTun) Events() chan tun.Event { return t.events } func (t *nilTun) Read(data []byte, offset int) (int, error) { <-t.closed return 0, io.EOF } func (t *nilTun) Write(data []byte, offset int) (int, error) { <-t.closed return 0, io.EOF } func (t *nilTun) Close() error { close(t.events) close(t.closed) return nil } // A noopBind is a conn.Bind that does no actual binding work. type noopBind struct{} func (noopBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { return nil, 1, nil } func (noopBind) Close() error { return nil } func (noopBind) SetMark(mark uint32) error { return nil } func (noopBind) Send(b []byte, ep conn.Endpoint) error { return nil } func (noopBind) ParseEndpoint(s string) (conn.Endpoint, error) { return dummyEndpoint(s), nil } // A dummyEndpoint is a string holding the endpoint destination. 
type dummyEndpoint string

// DstToString reports the destination exactly as it was parsed.
func (e dummyEndpoint) DstToString() string { return string(e) }

// DstToBytes has no binary form to offer, so it returns nil.
func (dummyEndpoint) DstToBytes() []byte { return nil }

// DstIP is never resolved for a dummy endpoint and is always nil.
func (dummyEndpoint) DstIP() net.IP { return nil }

// ClearSrc does nothing: no source address is ever recorded.
func (dummyEndpoint) ClearSrc() {}

// SrcToString always yields the empty string, as there is no source.
func (dummyEndpoint) SrcToString() string { return "" }

// SrcIP is always nil, as there is no source.
func (dummyEndpoint) SrcIP() net.IP { return nil }