From 59ba2e63169c7aa54644a22e6533752053dc3023 Mon Sep 17 00:00:00 2001 From: David Anderson Date: Mon, 17 Feb 2020 15:01:23 -0800 Subject: [PATCH] ipn: implement Prefs.Equals efficiently. Signed-off-by: David Anderson --- control/controlclient/direct.go | 15 +++ control/controlclient/persist_test.go | 98 ++++++++++++++++ ipn/prefs.go | 93 +++++++++------ ipn/prefs_test.go | 157 ++++++++++++++++++++++++++ 4 files changed, 329 insertions(+), 34 deletions(-) create mode 100644 control/controlclient/persist_test.go diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index 9c1cd1ff6..a91b52dff 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -39,6 +39,21 @@ type Persist struct { LoginName string } +func (p *Persist) Equals(p2 *Persist) bool { + if p == nil && p2 == nil { + return true + } + if p == nil || p2 == nil { + return false + } + + return p.PrivateMachineKey.Equal(p2.PrivateMachineKey) && + p.PrivateNodeKey.Equal(p2.PrivateNodeKey) && + p.OldPrivateNodeKey.Equal(p2.OldPrivateNodeKey) && + p.Provider == p2.Provider && + p.LoginName == p2.LoginName +} + func (p *Persist) Pretty() string { var mk, ok, nk wgcfg.Key if !p.PrivateMachineKey.IsZero() { diff --git a/control/controlclient/persist_test.go b/control/controlclient/persist_test.go new file mode 100644 index 000000000..621da8bbf --- /dev/null +++ b/control/controlclient/persist_test.go @@ -0,0 +1,98 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package controlclient + +import ( + "reflect" + "testing" + + "github.com/tailscale/wireguard-go/wgcfg" +) + +func TestPersistEqual(t *testing.T) { + persistHandles := []string{"PrivateMachineKey", "PrivateNodeKey", "OldPrivateNodeKey", "Provider", "LoginName"} + if have := fieldsOf(reflect.TypeOf(Persist{})); !reflect.DeepEqual(have, persistHandles) { + t.Errorf("Persist.Equal check might be out of sync\nfields: %q\nhandled: %q\n", + have, persistHandles) + } + + newPrivate := func() wgcfg.PrivateKey { + k, err := wgcfg.NewPrivateKey() + if err != nil { + panic(err) + } + return k + } + k1 := newPrivate() + tests := []struct { + a, b *Persist + want bool + }{ + {nil, nil, true}, + {nil, &Persist{}, false}, + {&Persist{}, nil, false}, + {&Persist{}, &Persist{}, true}, + + { + &Persist{PrivateMachineKey: k1}, + &Persist{PrivateMachineKey: newPrivate()}, + false, + }, + { + &Persist{PrivateMachineKey: k1}, + &Persist{PrivateMachineKey: k1}, + true, + }, + + { + &Persist{PrivateNodeKey: k1}, + &Persist{PrivateNodeKey: newPrivate()}, + false, + }, + { + &Persist{PrivateNodeKey: k1}, + &Persist{PrivateNodeKey: k1}, + true, + }, + + { + &Persist{OldPrivateNodeKey: k1}, + &Persist{OldPrivateNodeKey: newPrivate()}, + false, + }, + { + &Persist{OldPrivateNodeKey: k1}, + &Persist{OldPrivateNodeKey: k1}, + true, + }, + + { + &Persist{Provider: "google"}, + &Persist{Provider: "o365"}, + false, + }, + { + &Persist{Provider: "google"}, + &Persist{Provider: "google"}, + true, + }, + + { + &Persist{LoginName: "foo@tailscale.com"}, + &Persist{LoginName: "bar@tailscale.com"}, + false, + }, + { + &Persist{LoginName: "foo@tailscale.com"}, + &Persist{LoginName: "foo@tailscale.com"}, + true, + }, + } + for i, test := range tests { + if got := test.a.Equals(test.b); got != test.want { + t.Errorf("%d. Equals = %v; want %v", i, got, test.want) + } + } +} diff --git a/ipn/prefs.go b/ipn/prefs.go index 6da71bfaa..4ff5a347d 100644 --- a/ipn/prefs.go +++ b/ipn/prefs.go @@ -36,32 +36,57 @@ type Prefs struct { } // IsEmpty reports whether p is nil or pointing to a Prefs zero value. -func (uc *Prefs) IsEmpty() bool { return uc == nil || uc.Equals(&Prefs{}) } +func (p *Prefs) IsEmpty() bool { return p == nil || p.Equals(&Prefs{}) } -func (uc *Prefs) Pretty() string { - var ucp string - if uc.Persist != nil { - ucp = uc.Persist.Pretty() +func (p *Prefs) Pretty() string { + var pp string + if p.Persist != nil { + pp = p.Persist.Pretty() } else { - ucp = "Persist=nil" + pp = "Persist=nil" } return fmt.Sprintf("Prefs{ra=%v mesh=%v dns=%v want=%v notepad=%v pf=%v routes=%v %v}", - uc.RouteAll, uc.AllowSingleHosts, uc.CorpDNS, uc.WantRunning, - uc.NotepadURLs, uc.UsePacketFilter, uc.AdvertiseRoutes, ucp) + p.RouteAll, p.AllowSingleHosts, p.CorpDNS, p.WantRunning, + p.NotepadURLs, p.UsePacketFilter, p.AdvertiseRoutes, pp) } -func (uc *Prefs) ToBytes() []byte { - data, err := json.MarshalIndent(uc, "", "\t") +func (p *Prefs) ToBytes() []byte { + data, err := json.MarshalIndent(p, "", "\t") if err != nil { log.Fatalf("Prefs marshal: %v\n", err) } return data } -func (uc *Prefs) Equals(uc2 *Prefs) bool { - b1 := uc.ToBytes() - b2 := uc2.ToBytes() - return bytes.Equal(b1, b2) +func (p *Prefs) Equals(p2 *Prefs) bool { + if p == nil && p2 == nil { + return true + } + if p == nil || p2 == nil { + return false + } + + return p != nil && p2 != nil && + p.RouteAll == p2.RouteAll && + p.AllowSingleHosts == p2.AllowSingleHosts && + p.CorpDNS == p2.CorpDNS && + p.WantRunning == p2.WantRunning && + p.NotepadURLs == p2.NotepadURLs && + p.UsePacketFilter == p2.UsePacketFilter && + compareIPNets(p.AdvertiseRoutes, p2.AdvertiseRoutes) && + p.Persist.Equals(p2.Persist) +} + +func compareIPNets(a, b []*net.IPNet) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if !a[i].IP.Equal(b[i].IP) || !bytes.Equal(a[i].Mask, b[i].Mask) { + return false + } + } + return true } func NewPrefs() Prefs { @@ -78,45 +103,45 @@ func NewPrefs() Prefs { } func PrefsFromBytes(b []byte, enforceDefaults bool) (Prefs, error) { - uc := NewPrefs() + p := NewPrefs() if len(b) == 0 { - return uc, nil + return p, nil } persist := &controlclient.Persist{} err := json.Unmarshal(b, persist) if err == nil && (persist.Provider != "" || persist.LoginName != "") { // old-style relaynode config; import it - uc.Persist = persist + p.Persist = persist } else { - err = json.Unmarshal(b, &uc) + err = json.Unmarshal(b, &p) if err != nil { log.Printf("Prefs parse: %v: %v\n", err, b) } } if enforceDefaults { - uc.RouteAll = true - uc.AllowSingleHosts = true + p.RouteAll = true + p.AllowSingleHosts = true } - return uc, err + return p, err } -func (uc *Prefs) Copy() *Prefs { - uc2, err := PrefsFromBytes(uc.ToBytes(), false) +func (p *Prefs) Copy() *Prefs { + p2, err := PrefsFromBytes(p.ToBytes(), false) if err != nil { log.Fatalf("Prefs was uncopyable: %v\n", err) } - return &uc2 + return &p2 } func LoadPrefs(filename string, enforceDefaults bool) *Prefs { log.Printf("Loading prefs %v\n", filename) data, err := ioutil.ReadFile(filename) - uc := NewPrefs() + p := NewPrefs() if err != nil { log.Printf("Read: %v: %v\n", filename, err) goto fail } - uc, err = PrefsFromBytes(data, enforceDefaults) + p, err = PrefsFromBytes(data, enforceDefaults) if err != nil { log.Printf("Parse: %v: %v\n", filename, err) goto fail @@ -124,8 +149,8 @@ func LoadPrefs(filename string, enforceDefaults bool) *Prefs { goto post fail: log.Printf("failed to load config. Generating a new one.\n") - uc = NewPrefs() - uc.WantRunning = true + p = NewPrefs() + p.WantRunning = true post: // Update: we changed our minds :) // Versabank would like to persist the setting across reboots, for now, @@ -138,15 +163,15 @@ post: // know how, rebooting will fix it. // We still persist WantRunning just in case we change our minds on // this topic. - uc.WantRunning = true + p.WantRunning = true } - log.Printf("Loaded prefs %v %v\n", filename, uc.Pretty()) - return &uc + log.Printf("Loaded prefs %v %v\n", filename, p.Pretty()) + return &p } -func SavePrefs(filename string, uc *Prefs) { - log.Printf("Saving prefs %v %v\n", filename, uc.Pretty()) - data := uc.ToBytes() +func SavePrefs(filename string, p *Prefs) { + log.Printf("Saving prefs %v %v\n", filename, p.Pretty()) + data := p.ToBytes() os.MkdirAll(filepath.Dir(filename), 0700) if err := atomicfile.WriteFile(filename, data, 0666); err != nil { log.Printf("SavePrefs: %v\n", err) diff --git a/ipn/prefs_test.go b/ipn/prefs_test.go index 95e5cca6c..aedcca750 100644 --- a/ipn/prefs_test.go +++ b/ipn/prefs_test.go @@ -5,11 +5,168 @@ package ipn import ( + "net" + "reflect" "testing" "tailscale.com/control/controlclient" ) +func fieldsOf(t reflect.Type) (fields []string) { + for i := 0; i < t.NumField(); i++ { + fields = append(fields, t.Field(i).Name) + } + return +} + +func TestPrefsEqual(t *testing.T) { + prefsHandles := []string{"RouteAll", "AllowSingleHosts", "CorpDNS", "WantRunning", "NotepadURLs", "UsePacketFilter", "AdvertiseRoutes", "Persist"} + if have := fieldsOf(reflect.TypeOf(Prefs{})); !reflect.DeepEqual(have, prefsHandles) { + t.Errorf("Prefs.Equal check might be out of sync\nfields: %q\nhandled: %q\n", + have, prefsHandles) + } + + nets := func(strs ...string) (ns []*net.IPNet) { + for _, s := range strs { + _, n, err := net.ParseCIDR(s) + if err != nil { + panic(err) + } + ns = append(ns, n) + } + return ns + } + tests := []struct { + a, b *Prefs + want bool + }{ + { + &Prefs{}, + nil, + false, + }, + { + nil, + &Prefs{}, + false, + }, + { + &Prefs{}, + &Prefs{}, + true, + }, + + { + &Prefs{RouteAll: true}, + &Prefs{RouteAll: false}, + false, + }, + { + &Prefs{RouteAll: true}, + &Prefs{RouteAll: true}, + true, + }, + + { + &Prefs{AllowSingleHosts: true}, + &Prefs{AllowSingleHosts: false}, + false, + }, + { + &Prefs{AllowSingleHosts: true}, + &Prefs{AllowSingleHosts: true}, + true, + }, + + { + &Prefs{CorpDNS: true}, + &Prefs{CorpDNS: false}, + false, + }, + { + &Prefs{CorpDNS: true}, + &Prefs{CorpDNS: true}, + true, + }, + + { + &Prefs{WantRunning: true}, + &Prefs{WantRunning: false}, + false, + }, + { + &Prefs{WantRunning: true}, + &Prefs{WantRunning: true}, + true, + }, + + { + &Prefs{NotepadURLs: true}, + &Prefs{NotepadURLs: false}, + false, + }, + { + &Prefs{NotepadURLs: true}, + &Prefs{NotepadURLs: true}, + true, + }, + + { + &Prefs{UsePacketFilter: true}, + &Prefs{UsePacketFilter: false}, + false, + }, + { + &Prefs{UsePacketFilter: true}, + &Prefs{UsePacketFilter: true}, + true, + }, + + { + &Prefs{AdvertiseRoutes: nil}, + &Prefs{AdvertiseRoutes: []*net.IPNet{}}, + true, + }, + { + &Prefs{AdvertiseRoutes: []*net.IPNet{}}, + &Prefs{AdvertiseRoutes: []*net.IPNet{}}, + true, + }, + { + &Prefs{AdvertiseRoutes: nets("192.168.0.0/24", "10.1.0.0/16")}, + &Prefs{AdvertiseRoutes: nets("192.168.1.0/24", "10.2.0.0/16")}, + false, + }, + { + &Prefs{AdvertiseRoutes: nets("192.168.0.0/24", "10.1.0.0/16")}, + &Prefs{AdvertiseRoutes: nets("192.168.0.0/24", "10.2.0.0/16")}, + false, + }, + { + &Prefs{AdvertiseRoutes: nets("192.168.0.0/24", "10.1.0.0/16")}, + &Prefs{AdvertiseRoutes: nets("192.168.0.0/24", "10.1.0.0/16")}, + true, + }, + + { + &Prefs{Persist: &controlclient.Persist{}}, + &Prefs{Persist: &controlclient.Persist{LoginName: "dave"}}, + false, + }, + { + &Prefs{Persist: &controlclient.Persist{LoginName: "dave"}}, + &Prefs{Persist: &controlclient.Persist{LoginName: "dave"}}, + true, + }, + } + for i, tt := range tests { + got := tt.a.Equals(tt.b) + if got != tt.want { + t.Errorf("%d. Equal = %v; want %v", i, got, tt.want) + } + } +} + func checkPrefs(t *testing.T, p Prefs) { var err error var p2, p2c Prefs