diff --git a/types/wgkey/key.go b/types/wgkey/key.go index b9ba6deed..d96363952 100644 --- a/types/wgkey/key.go +++ b/types/wgkey/key.go @@ -27,7 +27,7 @@ import ( const Size = 32 // A Key is a curve25519 key. -// It is used by WireGuard to represent public keys. +// It is used by WireGuard to represent public and preshared keys. type Key [Size]byte // NewPreshared generates a new random Key. diff --git a/wgengine/bench/wg.go b/wgengine/bench/wg.go index acf04f6bd..167975a3f 100644 --- a/wgengine/bench/wg.go +++ b/wgengine/bench/wg.go @@ -27,7 +27,7 @@ import ( func setupWGTest(logf logger.Logf, traf *TrafficGen, a1, a2 netaddr.IPPrefix) { l1 := logger.WithPrefix(logf, "e1: ") - k1, err := wgcfg.NewPrivateKey() + k1, err := wgkey.NewPrivate() if err != nil { log.Fatalf("e1 NewPrivateKey: %v", err) } @@ -51,7 +51,7 @@ func setupWGTest(logf logger.Logf, traf *TrafficGen, a1, a2 netaddr.IPPrefix) { } l2 := logger.WithPrefix(logf, "e2: ") - k2, err := wgcfg.NewPrivateKey() + k2, err := wgkey.NewPrivate() if err != nil { log.Fatalf("e2 NewPrivateKey: %v", err) } diff --git a/wgengine/magicsock/legacy.go b/wgengine/magicsock/legacy.go index b55347054..db59580e3 100644 --- a/wgengine/magicsock/legacy.go +++ b/wgengine/magicsock/legacy.go @@ -27,7 +27,6 @@ import ( "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/wgkey" - "tailscale.com/wgengine/wgcfg" ) var ( @@ -591,8 +590,8 @@ func init() { type messageInitiation struct { Type uint32 Sender uint32 - Ephemeral wgcfg.Key - Static [wgcfg.KeySize + poly1305.TagSize]byte + Ephemeral wgkey.Key + Static [wgkey.Size + poly1305.TagSize]byte Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte MAC1 [blake2s.Size128]byte MAC2 [blake2s.Size128]byte diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 8490dee66..99850e32d 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -444,7 +444,7 @@ func TestPickDERPFallback(t *testing.T) { func makeConfigs(t *testing.T, addrs []netaddr.IPPort) []wgcfg.Config { t.Helper() - var privKeys []wgcfg.PrivateKey + var privKeys []wgkey.Private var addresses [][]netaddr.IPPrefix for i := range addrs { @@ -452,7 +452,7 @@ func makeConfigs(t *testing.T, addrs []netaddr.IPPort) []wgcfg.Config { if err != nil { t.Fatal(err) } - privKeys = append(privKeys, wgcfg.PrivateKey(privKey)) + privKeys = append(privKeys, wgkey.Private(privKey)) addresses = append(addresses, []netaddr.IPPrefix{ parseCIDR(t, fmt.Sprintf("1.0.0.%d/32", i+1)), diff --git a/wgengine/wgcfg/config.go b/wgengine/wgcfg/config.go index 08328bcea..aa8850e5d 100644 --- a/wgengine/wgcfg/config.go +++ b/wgengine/wgcfg/config.go @@ -7,6 +7,7 @@ package wgcfg import ( "inet.af/netaddr" + "tailscale.com/types/wgkey" ) // EndpointDiscoSuffix is appended to the hex representation of a peer's discovery key @@ -18,7 +19,7 @@ const EndpointDiscoSuffix = ".disco.tailscale:12345" // It only supports the set of things Tailscale uses. type Config struct { Name string - PrivateKey PrivateKey + PrivateKey wgkey.Private Addresses []netaddr.IPPrefix MTU uint16 DNS []netaddr.IP @@ -26,7 +27,7 @@ type Config struct { } type Peer struct { - PublicKey Key + PublicKey wgkey.Key AllowedIPs []netaddr.IPPrefix Endpoints string // comma-separated host/port pairs: "1.2.3.4:56,[::]:80" PersistentKeepalive uint16 @@ -61,7 +62,7 @@ func (peer Peer) Copy() Peer { } // PeerWithKey returns the Peer with key k and reports whether it was found. -func (config Config) PeerWithKey(k Key) (Peer, bool) { +func (config Config) PeerWithKey(k wgkey.Key) (Peer, bool) { for _, p := range config.Peers { if p.PublicKey == k { return p, true diff --git a/wgengine/wgcfg/device_test.go b/wgengine/wgcfg/device_test.go index 1064b13ab..70b45762a 100644 --- a/wgengine/wgcfg/device_test.go +++ b/wgengine/wgcfg/device_test.go @@ -22,13 +22,13 @@ import ( ) func TestDeviceConfig(t *testing.T) { - newPrivateKey := func() (Key, PrivateKey) { + newPrivateKey := func() (wgkey.Key, wgkey.Private) { t.Helper() pk, err := wgkey.NewPrivate() if err != nil { t.Fatal(err) } - return Key(pk.Public()), PrivateKey(pk) + return wgkey.Key(pk.Public()), wgkey.Private(pk) } k1, pk1 := newPrivateKey() ip1 := netaddr.MustParseIPPrefix("10.0.0.1/32") @@ -40,7 +40,7 @@ func TestDeviceConfig(t *testing.T) { ip3 := netaddr.MustParseIPPrefix("10.0.0.3/32") cfg1 := &Config{ - PrivateKey: PrivateKey(pk1), + PrivateKey: wgkey.Private(pk1), Peers: []Peer{{ PublicKey: k2, AllowedIPs: []netaddr.IPPrefix{ip2}, @@ -48,7 +48,7 @@ func TestDeviceConfig(t *testing.T) { } cfg2 := &Config{ - PrivateKey: PrivateKey(pk2), + PrivateKey: wgkey.Private(pk2), Peers: []Peer{{ PublicKey: k1, AllowedIPs: []netaddr.IPPrefix{ip1}, diff --git a/wgengine/wgcfg/key.go b/wgengine/wgcfg/key.go deleted file mode 100644 index 48601df98..000000000 --- a/wgengine/wgcfg/key.go +++ /dev/null @@ -1,240 +0,0 @@ -// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package wgcfg - -import ( - "bytes" - "crypto/rand" - "crypto/subtle" - "encoding/base64" - "encoding/hex" - "errors" - "fmt" - "strings" - - "golang.org/x/crypto/chacha20poly1305" - "golang.org/x/crypto/curve25519" -) - -const KeySize = 32 - -// Key is curve25519 key. -// It is used by WireGuard to represent public and preshared keys. -type Key [KeySize]byte - -// NewPresharedKey generates a new random key. -func NewPresharedKey() (*Key, error) { - var k [KeySize]byte - _, err := rand.Read(k[:]) - if err != nil { - return nil, err - } - return (*Key)(&k), nil -} - -func ParseKey(b64 string) (*Key, error) { return parseKeyBase64(base64.StdEncoding, b64) } - -func ParseHexKey(s string) (Key, error) { - b, err := hex.DecodeString(s) - if err != nil { - return Key{}, &ParseError{"invalid hex key: " + err.Error(), s} - } - if len(b) != KeySize { - return Key{}, &ParseError{fmt.Sprintf("invalid hex key length: %d", len(b)), s} - } - - var key Key - copy(key[:], b) - return key, nil -} - -func ParsePrivateHexKey(v string) (PrivateKey, error) { - k, err := ParseHexKey(v) - if err != nil { - return PrivateKey{}, err - } - pk := PrivateKey(k) - if pk.IsZero() { - // Do not clamp a zero key, pass the zero through - // (much like NaN propagation) so that IsZero reports - // a useful result. - return pk, nil - } - pk.clamp() - return pk, nil -} - -func (k Key) Base64() string { return base64.StdEncoding.EncodeToString(k[:]) } -func (k Key) String() string { return k.ShortString() } -func (k Key) HexString() string { return hex.EncodeToString(k[:]) } -func (k Key) Equal(k2 Key) bool { return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 } - -func (k *Key) ShortString() string { - long := k.Base64() - return "[" + long[0:5] + "]" -} - -func (k *Key) IsZero() bool { - if k == nil { - return true - } - var zeros Key - return subtle.ConstantTimeCompare(zeros[:], k[:]) == 1 -} - -func (k *Key) MarshalJSON() ([]byte, error) { - if k == nil { - return []byte("null"), nil - } - buf := new(bytes.Buffer) - fmt.Fprintf(buf, `"%x"`, k[:]) - return buf.Bytes(), nil -} - -func (k *Key) UnmarshalJSON(b []byte) error { - if k == nil { - return errors.New("wgcfg.Key: UnmarshalJSON on nil pointer") - } - if len(b) < 3 || b[0] != '"' || b[len(b)-1] != '"' { - return errors.New("wgcfg.Key: UnmarshalJSON not given a string") - } - b = b[1 : len(b)-1] - key, err := ParseHexKey(string(b)) - if err != nil { - return fmt.Errorf("wgcfg.Key: UnmarshalJSON: %v", err) - } - copy(k[:], key[:]) - return nil -} - -func (a *Key) LessThan(b *Key) bool { - for i := range a { - if a[i] < b[i] { - return true - } else if a[i] > b[i] { - return false - } - } - return false -} - -// PrivateKey is curve25519 key. -// It is used by WireGuard to represent private keys. -type PrivateKey [KeySize]byte - -// NewPrivateKey generates a new curve25519 secret key. -// It conforms to the format described on https://cr.yp.to/ecdh.html. -func NewPrivateKey() (PrivateKey, error) { - k, err := NewPresharedKey() - if err != nil { - return PrivateKey{}, err - } - k[0] &= 248 - k[31] = (k[31] & 127) | 64 - return (PrivateKey)(*k), nil -} - -func ParsePrivateKey(b64 string) (*PrivateKey, error) { - k, err := parseKeyBase64(base64.StdEncoding, b64) - return (*PrivateKey)(k), err -} - -func (k *PrivateKey) String() string { return base64.StdEncoding.EncodeToString(k[:]) } -func (k *PrivateKey) HexString() string { return hex.EncodeToString(k[:]) } -func (k *PrivateKey) Equal(k2 PrivateKey) bool { return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 } - -func (k *PrivateKey) IsZero() bool { - pk := Key(*k) - return pk.IsZero() -} - -func (k *PrivateKey) clamp() { - k[0] &= 248 - k[31] = (k[31] & 127) | 64 -} - -// Public computes the public key matching this curve25519 secret key. -func (k *PrivateKey) Public() Key { - pk := Key(*k) - if pk.IsZero() { - panic("Tried to generate emptyPrivateKey.Public()") - } - var p [KeySize]byte - curve25519.ScalarBaseMult(&p, (*[KeySize]byte)(k)) - return (Key)(p) -} - -func (k PrivateKey) MarshalText() ([]byte, error) { - buf := new(bytes.Buffer) - fmt.Fprintf(buf, `privkey:%x`, k[:]) - return buf.Bytes(), nil -} - -func (k *PrivateKey) UnmarshalText(b []byte) error { - s := string(b) - if !strings.HasPrefix(s, `privkey:`) { - return errors.New("wgcfg.PrivateKey: UnmarshalText not given a private-key string") - } - s = strings.TrimPrefix(s, `privkey:`) - key, err := ParseHexKey(s) - if err != nil { - return fmt.Errorf("wgcfg.PrivateKey: UnmarshalText: %v", err) - } - copy(k[:], key[:]) - return nil -} - -func (k PrivateKey) SharedSecret(pub Key) (ss [KeySize]byte) { - apk := (*[KeySize]byte)(&pub) - ask := (*[KeySize]byte)(&k) - curve25519.ScalarMult(&ss, ask, apk) //lint:ignore SA1019 Jason says this is OK; match wireguard-go exactyl - return ss -} - -func parseKeyBase64(enc *base64.Encoding, s string) (*Key, error) { - k, err := enc.DecodeString(s) - if err != nil { - return nil, &ParseError{"Invalid key: " + err.Error(), s} - } - if len(k) != KeySize { - return nil, &ParseError{"Keys must decode to exactly 32 bytes", s} - } - var key Key - copy(key[:], k) - return &key, nil -} - -func ParseSymmetricKey(b64 string) (SymmetricKey, error) { - k, err := parseKeyBase64(base64.StdEncoding, b64) - if err != nil { - return SymmetricKey{}, err - } - return SymmetricKey(*k), nil -} - -func ParseSymmetricHexKey(s string) (SymmetricKey, error) { - b, err := hex.DecodeString(s) - if err != nil { - return SymmetricKey{}, &ParseError{"invalid symmetric hex key: " + err.Error(), s} - } - if len(b) != chacha20poly1305.KeySize { - return SymmetricKey{}, &ParseError{fmt.Sprintf("invalid symmetric hex key length: %d", len(b)), s} - } - var key SymmetricKey - copy(key[:], b) - return key, nil -} - -// SymmetricKey is a chacha20poly1305 key. -// It is used by WireGuard to represent pre-shared symmetric keys. -type SymmetricKey [chacha20poly1305.KeySize]byte - -func (k SymmetricKey) Base64() string { return base64.StdEncoding.EncodeToString(k[:]) } -func (k SymmetricKey) String() string { return "sym:" + k.Base64()[:8] } -func (k SymmetricKey) HexString() string { return hex.EncodeToString(k[:]) } -func (k SymmetricKey) IsZero() bool { return k.Equal(SymmetricKey{}) } -func (k SymmetricKey) Equal(k2 SymmetricKey) bool { - return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 -} diff --git a/wgengine/wgcfg/key_test.go b/wgengine/wgcfg/key_test.go deleted file mode 100644 index 709b1afcc..000000000 --- a/wgengine/wgcfg/key_test.go +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package wgcfg - -import ( - "bytes" - "testing" -) - -func TestKeyBasics(t *testing.T) { - k1, err := NewPresharedKey() - if err != nil { - t.Fatal(err) - } - - b, err := k1.MarshalJSON() - if err != nil { - t.Fatal(err) - } - - t.Run("JSON round-trip", func(t *testing.T) { - // should preserve the keys - k2 := new(Key) - if err := k2.UnmarshalJSON(b); err != nil { - t.Fatal(err) - } - if !bytes.Equal(k1[:], k2[:]) { - t.Fatalf("k1 %v != k2 %v", k1[:], k2[:]) - } - if b1, b2 := k1.String(), k2.String(); b1 != b2 { - t.Fatalf("base64-encoded keys do not match: %s, %s", b1, b2) - } - }) - - t.Run("JSON incompatible with PrivateKey", func(t *testing.T) { - k2 := new(PrivateKey) - if err := k2.UnmarshalText(b); err == nil { - t.Fatalf("successfully decoded key as private key") - } - }) - - t.Run("second key", func(t *testing.T) { - // A second call to NewPresharedKey should make a new key. - k3, err := NewPresharedKey() - if err != nil { - t.Fatal(err) - } - if bytes.Equal(k1[:], k3[:]) { - t.Fatalf("k1 %v == k3 %v", k1[:], k3[:]) - } - // Check for obvious comparables to make sure we are not generating bad strings somewhere. - if b1, b2 := k1.String(), k3.String(); b1 == b2 { - t.Fatalf("base64-encoded keys match: %s, %s", b1, b2) - } - }) -} -func TestPrivateKeyBasics(t *testing.T) { - pri, err := NewPrivateKey() - if err != nil { - t.Fatal(err) - } - - b, err := pri.MarshalText() - if err != nil { - t.Fatal(err) - } - - t.Run("JSON round-trip", func(t *testing.T) { - // should preserve the keys - pri2 := new(PrivateKey) - if err := pri2.UnmarshalText(b); err != nil { - t.Fatal(err) - } - if !bytes.Equal(pri[:], pri2[:]) { - t.Fatalf("pri %v != pri2 %v", pri[:], pri2[:]) - } - if b1, b2 := pri.String(), pri2.String(); b1 != b2 { - t.Fatalf("base64-encoded keys do not match: %s, %s", b1, b2) - } - if pub1, pub2 := pri.Public().String(), pri2.Public().String(); pub1 != pub2 { - t.Fatalf("base64-encoded public keys do not match: %s, %s", pub1, pub2) - } - }) - - t.Run("JSON incompatible with Key", func(t *testing.T) { - k2 := new(Key) - if err := k2.UnmarshalJSON(b); err == nil { - t.Fatalf("successfully decoded private key as key") - } - }) - - t.Run("second key", func(t *testing.T) { - // A second call to New should make a new key. - pri3, err := NewPrivateKey() - if err != nil { - t.Fatal(err) - } - if bytes.Equal(pri[:], pri3[:]) { - t.Fatalf("pri %v == pri3 %v", pri[:], pri3[:]) - } - // Check for obvious comparables to make sure we are not generating bad strings somewhere. - if b1, b2 := pri.String(), pri3.String(); b1 == b2 { - t.Fatalf("base64-encoded keys match: %s, %s", b1, b2) - } - if pub1, pub2 := pri.Public().String(), pri3.Public().String(); pub1 == pub2 { - t.Fatalf("base64-encoded public keys match: %s, %s", pub1, pub2) - } - }) -} diff --git a/wgengine/wgcfg/nmcfg/nmcfg.go b/wgengine/wgcfg/nmcfg/nmcfg.go index 443598fdc..8fe1c062a 100644 --- a/wgengine/wgcfg/nmcfg/nmcfg.go +++ b/wgengine/wgcfg/nmcfg/nmcfg.go @@ -18,6 +18,7 @@ import ( "tailscale.com/tailcfg" "tailscale.com/types/logger" "tailscale.com/types/netmap" + "tailscale.com/types/wgkey" "tailscale.com/wgengine/wgcfg" ) @@ -56,7 +57,7 @@ func cidrIsSubnet(node *tailcfg.Node, cidr netaddr.IPPrefix) bool { func WGCfg(nm *netmap.NetworkMap, logf logger.Logf, flags netmap.WGConfigFlags, exitNode tailcfg.StableNodeID) (*wgcfg.Config, error) { cfg := &wgcfg.Config{ Name: "tailscale", - PrivateKey: wgcfg.PrivateKey(nm.PrivateKey), + PrivateKey: wgkey.Private(nm.PrivateKey), Addresses: nm.Addresses, Peers: make([]wgcfg.Peer, 0, len(nm.Peers)), } @@ -71,7 +72,7 @@ func WGCfg(nm *netmap.NetworkMap, logf logger.Logf, flags netmap.WGConfigFlags, continue } cfg.Peers = append(cfg.Peers, wgcfg.Peer{ - PublicKey: wgcfg.Key(peer.Key), + PublicKey: wgkey.Key(peer.Key), }) cpeer := &cfg.Peers[len(cfg.Peers)-1] if peer.KeepAlive { diff --git a/wgengine/wgcfg/parser.go b/wgengine/wgcfg/parser.go index ec66eb6bc..318728a6b 100644 --- a/wgengine/wgcfg/parser.go +++ b/wgengine/wgcfg/parser.go @@ -14,6 +14,7 @@ import ( "strings" "inet.af/netaddr" + "tailscale.com/types/wgkey" ) type ParseError struct { @@ -69,15 +70,15 @@ func parseEndpoint(s string) (host string, port uint16, err error) { return host, uint16(uport), nil } -func parseKeyHex(s string) (*Key, error) { +func parseKeyHex(s string) (*wgkey.Key, error) { k, err := hex.DecodeString(s) if err != nil { return nil, &ParseError{"Invalid key: " + err.Error(), s} } - if len(k) != KeySize { + if len(k) != wgkey.Size { return nil, &ParseError{"Keys must decode to exactly 32 bytes", s} } - var key Key + var key wgkey.Key copy(key[:], k) return &key, nil } @@ -142,7 +143,7 @@ func (cfg *Config) handleDeviceLine(key, value string) error { return err } // wireguard-go guarantees not to send zero value; private keys are already clamped. - cfg.PrivateKey = PrivateKey(*k) + cfg.PrivateKey = wgkey.Private(*k) case "listen_port": // ignore case "fwmark": diff --git a/wgengine/wgcfg/writer.go b/wgengine/wgcfg/writer.go index f64b71bf3..9e3462c38 100644 --- a/wgengine/wgcfg/writer.go +++ b/wgengine/wgcfg/writer.go @@ -12,6 +12,7 @@ import ( "strings" "inet.af/netaddr" + "tailscale.com/types/wgkey" ) // ToUAPI writes cfg in UAPI format to w. @@ -41,7 +42,7 @@ func (cfg *Config) ToUAPI(w io.Writer, prev *Config) error { set("private_key", cfg.PrivateKey.HexString()) } - old := make(map[Key]Peer) + old := make(map[wgkey.Key]Peer) for _, p := range prev.Peers { old[p.PublicKey] = p } diff --git a/wgengine/wglog/wglog.go b/wgengine/wglog/wglog.go index fc10de66a..a99ed9c55 100644 --- a/wgengine/wglog/wglog.go +++ b/wgengine/wglog/wglog.go @@ -13,6 +13,7 @@ import ( "github.com/tailscale/wireguard-go/device" "tailscale.com/types/logger" + "tailscale.com/types/wgkey" "tailscale.com/wgengine/wgcfg" ) @@ -86,7 +87,7 @@ func (x *Logger) SetPeers(peers []wgcfg.Peer) { } // wireguardGoString prints p in the same format used by wireguard-go. -func wireguardGoString(k wgcfg.Key) string { +func wireguardGoString(k wgkey.Key) string { base64Key := base64.StdEncoding.EncodeToString(k[:]) abbreviatedKey := "invalid" if len(base64Key) == 44 { diff --git a/wgengine/wglog/wglog_test.go b/wgengine/wglog/wglog_test.go index 077981e41..80ea67449 100644 --- a/wgengine/wglog/wglog_test.go +++ b/wgengine/wglog/wglog_test.go @@ -8,6 +8,7 @@ import ( "fmt" "testing" + "tailscale.com/types/wgkey" "tailscale.com/wgengine/wgcfg" "tailscale.com/wgengine/wglog" ) @@ -34,7 +35,7 @@ func TestLogger(t *testing.T) { } x := wglog.NewLogger(logf) - key, err := wgcfg.ParseHexKey("20c4c1ae54e1fd37cab6e9a532ca20646aff496796cc41d4519560e5e82bee53") + key, err := wgkey.ParseHex("20c4c1ae54e1fd37cab6e9a532ca20646aff496796cc41d4519560e5e82bee53") if err != nil { t.Fatal(err) }