From 37c150aee1d9d9a805d4f520ec808fffee84b439 Mon Sep 17 00:00:00 2001 From: David Anderson Date: Thu, 28 Oct 2021 15:42:50 -0700 Subject: [PATCH] derp: use new node key type. Update #3206 Signed-off-by: David Anderson --- cmd/derper/derper.go | 2 +- cmd/derper/mesh.go | 4 +- cmd/derpprobe/derpprobe.go | 2 +- cmd/tailscale/depaware.txt | 2 +- cmd/tailscaled/debug.go | 4 +- cmd/tailscaled/depaware.txt | 2 +- derp/derp_client.go | 74 +++++------ derp/derp_server.go | 182 ++++++++++----------------- derp/derp_test.go | 119 +++++++++--------- derp/derphttp/derphttp_client.go | 28 ++--- derp/derphttp/derphttp_server.go | 4 +- derp/derphttp/derphttp_test.go | 8 +- derp/derphttp/mesh_client.go | 12 +- tstest/integration/integration.go | 7 +- wgengine/magicsock/magicsock.go | 12 +- wgengine/magicsock/magicsock_test.go | 6 +- 16 files changed, 199 insertions(+), 269 deletions(-) diff --git a/cmd/derper/derper.go b/cmd/derper/derper.go index 4fd6a0dbd..a8b247eb3 100644 --- a/cmd/derper/derper.go +++ b/cmd/derper/derper.go @@ -143,7 +143,7 @@ func main() { serveTLS := tsweb.IsProd443(*addr) - s := derp.NewServer(cfg.PrivateKey.AsPrivate(), log.Printf) + s := derp.NewServer(cfg.PrivateKey, log.Printf) s.SetVerifyClient(*verifyClients) if *meshPSKFile != "" { diff --git a/cmd/derper/mesh.go b/cmd/derper/mesh.go index 9c9a936f5..24ea83e5f 100644 --- a/cmd/derper/mesh.go +++ b/cmd/derper/mesh.go @@ -69,8 +69,8 @@ func startMeshWithHost(s *derp.Server, host string) error { return d.DialContext(ctx, network, addr) }) - add := func(k key.Public) { s.AddPacketForwarder(k, c) } - remove := func(k key.Public) { s.RemovePacketForwarder(k, c) } + add := func(k key.NodePublic) { s.AddPacketForwarder(k, c) } + remove := func(k key.NodePublic) { s.RemovePacketForwarder(k, c) } go c.RunWatchConnectionLoop(context.Background(), s.PublicKey(), logf, add, remove) return nil } diff --git a/cmd/derpprobe/derpprobe.go b/cmd/derpprobe/derpprobe.go index 89e1224e0..3adde9f81 100644 --- a/cmd/derpprobe/derpprobe.go +++ b/cmd/derpprobe/derpprobe.go @@ -344,7 +344,7 @@ func probeNodePair(ctx context.Context, dm *tailcfg.DERPMap, from, to *tailcfg.D } func newConn(ctx context.Context, dm *tailcfg.DERPMap, n *tailcfg.DERPNode) (*derphttp.Client, error) { - priv := key.NewPrivate() + priv := key.NewNode() dc := derphttp.NewRegionClient(priv, log.Printf, func() *tailcfg.DERPRegion { rid := n.RegionID return &tailcfg.DERPRegion{ diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index e207669a4..c6a82000f 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -86,7 +86,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep golang.org/x/crypto/cryptobyte/asn1 from crypto/ecdsa+ golang.org/x/crypto/curve25519 from crypto/tls+ golang.org/x/crypto/hkdf from crypto/tls - golang.org/x/crypto/nacl/box from tailscale.com/derp+ + golang.org/x/crypto/nacl/box from tailscale.com/types/key golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box golang.org/x/crypto/poly1305 from golang.org/x/crypto/chacha20poly1305 golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ diff --git a/cmd/tailscaled/debug.go b/cmd/tailscaled/debug.go index ccd30c1d4..596b00b11 100644 --- a/cmd/tailscaled/debug.go +++ b/cmd/tailscaled/debug.go @@ -193,8 +193,8 @@ func checkDerp(ctx context.Context, derpRegion string) error { panic("unreachable") } - priv1 := key.NewPrivate() - priv2 := key.NewPrivate() + priv1 := key.NewNode() + priv2 := key.NewNode() c1 := derphttp.NewRegionClient(priv1, log.Printf, getRegion) c2 := derphttp.NewRegionClient(priv2, log.Printf, getRegion) diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index 8e1fe5154..50ca6877c 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -257,7 +257,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de golang.org/x/crypto/cryptobyte/asn1 from crypto/ecdsa+ golang.org/x/crypto/curve25519 from crypto/tls+ golang.org/x/crypto/hkdf from crypto/tls - golang.org/x/crypto/nacl/box from tailscale.com/derp+ + golang.org/x/crypto/nacl/box from tailscale.com/types/key+ golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box golang.org/x/crypto/poly1305 from golang.org/x/crypto/chacha20poly1305+ golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ diff --git a/derp/derp_client.go b/derp/derp_client.go index 75a148298..6982f38af 100644 --- a/derp/derp_client.go +++ b/derp/derp_client.go @@ -6,7 +6,6 @@ package derp import ( "bufio" - crand "crypto/rand" "encoding/binary" "encoding/json" "errors" @@ -15,7 +14,7 @@ import ( "sync" "time" - "golang.org/x/crypto/nacl/box" + "go4.org/mem" "golang.org/x/time/rate" "tailscale.com/types/key" "tailscale.com/types/logger" @@ -23,9 +22,9 @@ import ( // Client is a DERP client. type Client struct { - serverKey key.Public // of the DERP server; not a machine or node key - privateKey key.Private - publicKey key.Public // of privateKey + serverKey key.NodePublic // of the DERP server; not a machine or node key + privateKey key.NodePrivate + publicKey key.NodePublic // of privateKey logf logger.Logf nc Conn br *bufio.Reader @@ -54,7 +53,7 @@ func (f clientOptFunc) update(o *clientOpt) { f(o) } // clientOpt are the options passed to newClient. type clientOpt struct { MeshKey string - ServerPub key.Public + ServerPub key.NodePublic CanAckPings bool IsProber bool } @@ -71,7 +70,7 @@ func IsProber(v bool) ClientOpt { return clientOptFunc(func(o *clientOpt) { o.Is // ServerPublicKey returns a ClientOpt to declare that the server's DERP public key is known. // If key is the zero value, the returned ClientOpt is a no-op. -func ServerPublicKey(key key.Public) ClientOpt { +func ServerPublicKey(key key.NodePublic) ClientOpt { return clientOptFunc(func(o *clientOpt) { o.ServerPub = key }) } @@ -81,7 +80,7 @@ func CanAckPings(v bool) ClientOpt { return clientOptFunc(func(o *clientOpt) { o.CanAckPings = v }) } -func NewClient(privateKey key.Private, nc Conn, brw *bufio.ReadWriter, logf logger.Logf, opts ...ClientOpt) (*Client, error) { +func NewClient(privateKey key.NodePrivate, nc Conn, brw *bufio.ReadWriter, logf logger.Logf, opts ...ClientOpt) (*Client, error) { var opt clientOpt for _, o := range opts { if o == nil { @@ -92,7 +91,7 @@ func NewClient(privateKey key.Private, nc Conn, brw *bufio.ReadWriter, logf logg return newClient(privateKey, nc, brw, logf, opt) } -func newClient(privateKey key.Private, nc Conn, brw *bufio.ReadWriter, logf logger.Logf, opt clientOpt) (*Client, error) { +func newClient(privateKey key.NodePrivate, nc Conn, brw *bufio.ReadWriter, logf logger.Logf, opt clientOpt) (*Client, error) { c := &Client{ privateKey: privateKey, publicKey: privateKey.Public(), @@ -130,7 +129,7 @@ func (c *Client) recvServerKey() error { if flen < uint32(len(buf)) || t != frameServerKey || string(buf[:len(magic)]) != magic { return errors.New("invalid server greeting") } - copy(c.serverKey[:], buf[len(magic):]) + c.serverKey = key.NodePublicFromRaw32(mem.B(buf[len(magic):])) return nil } @@ -143,13 +142,9 @@ func (c *Client) parseServerInfo(b []byte) (*serverInfo, error) { if fl > maxLength { return nil, fmt.Errorf("long serverInfo frame") } - // TODO: add a read-nonce-and-box helper - var nonce [nonceLen]byte - copy(nonce[:], b) - msgbox := b[nonceLen:] - msg, ok := box.Open(nil, msgbox, &nonce, c.serverKey.B32(), c.privateKey.B32()) + msg, ok := c.privateKey.OpenFrom(c.serverKey, b) if !ok { - return nil, fmt.Errorf("failed to open naclbox from server key %x", c.serverKey[:]) + return nil, fmt.Errorf("failed to open naclbox from server key %s", c.serverKey) } info := new(serverInfo) if err := json.Unmarshal(msg, info); err != nil { @@ -176,10 +171,6 @@ type clientInfo struct { } func (c *Client) sendClientKey() error { - var nonce [nonceLen]byte - if _, err := crand.Read(nonce[:]); err != nil { - return err - } msg, err := json.Marshal(clientInfo{ Version: ProtocolVersion, MeshKey: c.meshKey, @@ -189,24 +180,23 @@ func (c *Client) sendClientKey() error { if err != nil { return err } - msgbox := box.Seal(nil, msg, &nonce, c.serverKey.B32(), c.privateKey.B32()) + msgbox := c.privateKey.SealTo(c.serverKey, msg) - buf := make([]byte, 0, nonceLen+keyLen+len(msgbox)) - buf = append(buf, c.publicKey[:]...) - buf = append(buf, nonce[:]...) + buf := make([]byte, 0, keyLen+len(msgbox)) + buf = c.publicKey.AppendTo(buf) buf = append(buf, msgbox...) return writeFrame(c.bw, frameClientInfo, buf) } // ServerPublicKey returns the server's public key. -func (c *Client) ServerPublicKey() key.Public { return c.serverKey } +func (c *Client) ServerPublicKey() key.NodePublic { return c.serverKey } // Send sends a packet to the Tailscale node identified by dstKey. // // It is an error if the packet is larger than 64KB. -func (c *Client) Send(dstKey key.Public, pkt []byte) error { return c.send(dstKey, pkt) } +func (c *Client) Send(dstKey key.NodePublic, pkt []byte) error { return c.send(dstKey, pkt) } -func (c *Client) send(dstKey key.Public, pkt []byte) (ret error) { +func (c *Client) send(dstKey key.NodePublic, pkt []byte) (ret error) { defer func() { if ret != nil { ret = fmt.Errorf("derp.Send: %w", ret) @@ -220,15 +210,15 @@ func (c *Client) send(dstKey key.Public, pkt []byte) (ret error) { c.wmu.Lock() defer c.wmu.Unlock() if c.rate != nil { - pktLen := frameHeaderLen + len(dstKey) + len(pkt) + pktLen := frameHeaderLen + dstKey.RawLen() + len(pkt) if !c.rate.AllowN(time.Now(), pktLen) { return nil // drop } } - if err := writeFrameHeader(c.bw, frameSendPacket, uint32(len(dstKey)+len(pkt))); err != nil { + if err := writeFrameHeader(c.bw, frameSendPacket, uint32(dstKey.RawLen()+len(pkt))); err != nil { return err } - if _, err := c.bw.Write(dstKey[:]); err != nil { + if _, err := c.bw.Write(dstKey.AppendTo(nil)); err != nil { return err } if _, err := c.bw.Write(pkt); err != nil { @@ -237,7 +227,7 @@ func (c *Client) send(dstKey key.Public, pkt []byte) (ret error) { return c.bw.Flush() } -func (c *Client) ForwardPacket(srcKey, dstKey key.Public, pkt []byte) (err error) { +func (c *Client) ForwardPacket(srcKey, dstKey key.NodePublic, pkt []byte) (err error) { defer func() { if err != nil { err = fmt.Errorf("derp.ForwardPacket: %w", err) @@ -257,10 +247,10 @@ func (c *Client) ForwardPacket(srcKey, dstKey key.Public, pkt []byte) (err error if err := writeFrameHeader(c.bw, frameForwardPacket, uint32(keyLen*2+len(pkt))); err != nil { return err } - if _, err := c.bw.Write(srcKey[:]); err != nil { + if _, err := c.bw.Write(srcKey.AppendTo(nil)); err != nil { return err } - if _, err := c.bw.Write(dstKey[:]); err != nil { + if _, err := c.bw.Write(dstKey.AppendTo(nil)); err != nil { return err } if _, err := c.bw.Write(pkt); err != nil { @@ -322,10 +312,10 @@ func (c *Client) WatchConnectionChanges() error { // ClosePeer asks the server to close target's TCP connection. // It's a fatal error if the client wasn't created using MeshKey. -func (c *Client) ClosePeer(target key.Public) error { +func (c *Client) ClosePeer(target key.NodePublic) error { c.wmu.Lock() defer c.wmu.Unlock() - return writeFrame(c.bw, frameClosePeer, target[:]) + return writeFrame(c.bw, frameClosePeer, target.AppendTo(nil)) } // ReceivedMessage represents a type returned by Client.Recv. Unless @@ -338,7 +328,7 @@ type ReceivedMessage interface { // ReceivedPacket is a ReceivedMessage representing an incoming packet. type ReceivedPacket struct { - Source key.Public + Source key.NodePublic // Data is the received packet bytes. It aliases the memory // passed to Client.Recv. Data []byte @@ -349,13 +339,13 @@ func (ReceivedPacket) msg() {} // PeerGoneMessage is a ReceivedMessage that indicates that the client // identified by the underlying public key had previously sent you a // packet but has now disconnected from the server. -type PeerGoneMessage key.Public +type PeerGoneMessage key.NodePublic func (PeerGoneMessage) msg() {} // PeerPresentMessage is a ReceivedMessage that indicates that the client // is connected to the server. (Only used by trusted mesh clients) -type PeerPresentMessage key.Public +type PeerPresentMessage key.NodePublic func (PeerPresentMessage) msg() {} @@ -516,8 +506,7 @@ func (c *Client) recvTimeout(timeout time.Duration) (m ReceivedMessage, err erro c.logf("[unexpected] dropping short peerGone frame from DERP server") continue } - var pg PeerGoneMessage - copy(pg[:], b[:keyLen]) + pg := PeerGoneMessage(key.NodePublicFromRaw32(mem.B(b[:keyLen]))) return pg, nil case framePeerPresent: @@ -525,8 +514,7 @@ func (c *Client) recvTimeout(timeout time.Duration) (m ReceivedMessage, err erro c.logf("[unexpected] dropping short peerPresent frame from DERP server") continue } - var pg PeerPresentMessage - copy(pg[:], b[:keyLen]) + pg := PeerPresentMessage(key.NodePublicFromRaw32(mem.B(b[:keyLen]))) return pg, nil case frameRecvPacket: @@ -535,7 +523,7 @@ func (c *Client) recvTimeout(timeout time.Duration) (m ReceivedMessage, err erro c.logf("[unexpected] dropping short packet from DERP server") continue } - copy(rp.Source[:], b[:keyLen]) + rp.Source = key.NodePublicFromRaw32(mem.B(b[:keyLen])) rp.Data = b[keyLen:n] return rp, nil diff --git a/derp/derp_server.go b/derp/derp_server.go index cccf4eacd..14790f1d2 100644 --- a/derp/derp_server.go +++ b/derp/derp_server.go @@ -34,7 +34,6 @@ import ( "time" "go4.org/mem" - "golang.org/x/crypto/nacl/box" "golang.org/x/sync/errgroup" "golang.org/x/time/rate" "inet.af/netaddr" @@ -52,7 +51,7 @@ var debug, _ = strconv.ParseBool(os.Getenv("DERP_DEBUG_LOGS")) // verboseDropKeys is the set of destination public keys that should // verbosely log whenever DERP drops a packet. -var verboseDropKeys = map[key.Public]bool{} +var verboseDropKeys = map[key.NodePublic]bool{} func init() { keys := os.Getenv("TS_DEBUG_VERBOSE_DROPS") @@ -60,7 +59,7 @@ func init() { return } for _, keyStr := range strings.Split(keys, ",") { - k, err := key.NewPublicFromHexMem(mem.S(keyStr)) + k, err := key.ParseNodePublicUntyped(mem.S(keyStr)) if err != nil { log.Printf("ignoring invalid debug key %q: %v", keyStr, err) } else { @@ -99,8 +98,8 @@ type Server struct { // before failing when writing to a client. WriteTimeout time.Duration - privateKey key.Private - publicKey key.Public + privateKey key.NodePrivate + publicKey key.NodePublic logf logger.Logf memSys0 uint64 // runtime.MemStats.Sys at start (or early-ish) meshKey string @@ -146,22 +145,22 @@ type Server struct { mu sync.Mutex closed bool netConns map[Conn]chan struct{} // chan is closed when conn closes - clients map[key.Public]clientSet + clients map[key.NodePublic]clientSet watchers map[*sclient]bool // mesh peer -> true // clientsMesh tracks all clients in the cluster, both locally // and to mesh peers. If the value is nil, that means the // peer is only local (and thus in the clients Map, but not // remote). If the value is non-nil, it's remote (+ maybe also // local). - clientsMesh map[key.Public]PacketForwarder + clientsMesh map[key.NodePublic]PacketForwarder // sentTo tracks which peers have sent to which other peers, // and at which connection number. This isn't on sclient // because it includes intra-region forwarded packets as the // src. - sentTo map[key.Public]map[key.Public]int64 // src => dst => dst's latest sclient.connNum + sentTo map[key.NodePublic]map[key.NodePublic]int64 // src => dst => dst's latest sclient.connNum // maps from netaddr.IPPort to a client's public key - keyOfAddr map[netaddr.IPPort]key.Public + keyOfAddr map[netaddr.IPPort]key.NodePublic } // clientSet represents 1 or more *sclients. @@ -277,7 +276,7 @@ func (s *dupClientSet) removeClient(c *sclient) bool { // is a multiForwarder, which this package creates as needed if a // public key gets more than one PacketForwarder registered for it. type PacketForwarder interface { - ForwardPacket(src, dst key.Public, payload []byte) error + ForwardPacket(src, dst key.NodePublic, payload []byte) error } // Conn is the subset of the underlying net.Conn the DERP Server needs. @@ -294,7 +293,7 @@ type Conn interface { // NewServer returns a new DERP server. It doesn't listen on its own. // Connections are given to it via Server.Accept. -func NewServer(privateKey key.Private, logf logger.Logf) *Server { +func NewServer(privateKey key.NodePrivate, logf logger.Logf) *Server { var ms runtime.MemStats runtime.ReadMemStats(&ms) @@ -306,14 +305,14 @@ func NewServer(privateKey key.Private, logf logger.Logf) *Server { packetsRecvByKind: metrics.LabelMap{Label: "kind"}, packetsDroppedReason: metrics.LabelMap{Label: "reason"}, packetsDroppedType: metrics.LabelMap{Label: "type"}, - clients: map[key.Public]clientSet{}, - clientsMesh: map[key.Public]PacketForwarder{}, + clients: map[key.NodePublic]clientSet{}, + clientsMesh: map[key.NodePublic]PacketForwarder{}, netConns: map[Conn]chan struct{}{}, memSys0: ms.Sys, watchers: map[*sclient]bool{}, - sentTo: map[key.Public]map[key.Public]int64{}, + sentTo: map[key.NodePublic]map[key.NodePublic]int64{}, avgQueueDuration: new(uint64), - keyOfAddr: map[netaddr.IPPort]key.Public{}, + keyOfAddr: map[netaddr.IPPort]key.NodePublic{}, } s.initMetacert() s.packetsRecvDisco = s.packetsRecvByKind.Get("disco") @@ -353,10 +352,10 @@ func (s *Server) HasMeshKey() bool { return s.meshKey != "" } func (s *Server) MeshKey() string { return s.meshKey } // PrivateKey returns the server's private key. -func (s *Server) PrivateKey() key.Private { return s.privateKey } +func (s *Server) PrivateKey() key.NodePrivate { return s.privateKey } // PublicKey returns the server's public key. -func (s *Server) PublicKey() key.Public { return s.publicKey } +func (s *Server) PublicKey() key.NodePublic { return s.publicKey } // Close closes the server and waits for the connections to disconnect. func (s *Server) Close() error { @@ -447,7 +446,7 @@ func (s *Server) initMetacert() { tmpl := &x509.Certificate{ SerialNumber: big.NewInt(ProtocolVersion), Subject: pkix.Name{ - CommonName: fmt.Sprintf("derpkey%x", s.publicKey[:]), + CommonName: fmt.Sprintf("derpkey%s", s.publicKey.UntypedHexString()), }, // Windows requires NotAfter and NotBefore set: NotAfter: time.Now().Add(30 * 24 * time.Hour), @@ -515,7 +514,7 @@ func (s *Server) registerClient(c *sclient) { // presence changed. // // s.mu must be held. -func (s *Server) broadcastPeerStateChangeLocked(peer key.Public, present bool) { +func (s *Server) broadcastPeerStateChangeLocked(peer key.NodePublic, present bool) { for w := range s.watchers { w.peerStateChange = append(w.peerStateChange, peerConnState{peer: peer, present: present}) go w.requestMeshUpdate() @@ -577,7 +576,7 @@ func (s *Server) unregisterClient(c *sclient) { // key has sent to previously (whether those sends were from a local // client or forwarded). It must only be called after the key has // been removed from clientsMesh. -func (s *Server) notePeerGoneFromRegionLocked(key key.Public) { +func (s *Server) notePeerGoneFromRegionLocked(key key.NodePublic) { if _, ok := s.clientsMesh[key]; ok { panic("usage") } @@ -663,7 +662,7 @@ func (s *Server) accept(nc Conn, brw *bufio.ReadWriter, remoteAddr string, connN connectedAt: time.Now(), sendQueue: make(chan pkt, perClientSendQueueDepth), discoSendQueue: make(chan pkt, perClientSendQueueDepth), - peerGone: make(chan key.Public), + peerGone: make(chan key.NodePublic), canMesh: clientInfo.MeshKey != "" && clientInfo.MeshKey == s.meshKey, } @@ -774,8 +773,8 @@ func (c *sclient) handleFrameClosePeer(ft frameType, fl uint32) error { if !c.canMesh { return fmt.Errorf("insufficient permissions") } - var targetKey key.Public - if _, err := io.ReadFull(c.br, targetKey[:]); err != nil { + var targetKey key.NodePublic + if err := targetKey.ReadRawWithoutAllocating(c.br); err != nil { return err } s := c.s @@ -845,10 +844,10 @@ func (c *sclient) handleFrameForwardPacket(ft frameType, fl uint32) error { // notePeerSendLocked records that src sent to dst. We keep track of // that so when src disconnects, we can tell dst (if it's still // around) that src is gone (a peerGone frame). -func (s *Server) notePeerSendLocked(src key.Public, dst *sclient) { +func (s *Server) notePeerSendLocked(src key.NodePublic, dst *sclient) { m, ok := s.sentTo[src] if !ok { - m = map[key.Public]int64{} + m = map[key.NodePublic]int64{} s.sentTo[src] = m } m[dst.key] = dst.connNum @@ -919,7 +918,7 @@ const ( dropReasonDupClient // the public key is connected 2+ times (active/active, fighting) ) -func (s *Server) recordDrop(packetBytes []byte, srcKey, dstKey key.Public, reason dropReason) { +func (s *Server) recordDrop(packetBytes []byte, srcKey, dstKey key.NodePublic, reason dropReason) { s.packetsDropped.Add(1) s.packetsDroppedReasonCounters[reason].Add(1) if disco.LooksLikeDiscoWrapper(packetBytes) { @@ -982,7 +981,7 @@ func (c *sclient) sendPkt(dst *sclient, p pkt) error { // requestPeerGoneWrite sends a request to write a "peer gone" frame // that the provided peer has disconnected. It blocks until either the // write request is scheduled, or the client has closed. -func (c *sclient) requestPeerGoneWrite(peer key.Public) { +func (c *sclient) requestPeerGoneWrite(peer key.NodePublic) { select { case c.peerGone <- peer: case <-c.done: @@ -999,7 +998,7 @@ func (c *sclient) requestMeshUpdate() { } } -func (s *Server) verifyClient(clientKey key.Public, info *clientInfo) error { +func (s *Server) verifyClient(clientKey key.NodePublic, info *clientInfo) error { if !s.verifyClients { return nil } @@ -1007,10 +1006,10 @@ func (s *Server) verifyClient(clientKey key.Public, info *clientInfo) error { if err != nil { return fmt.Errorf("failed to query local tailscaled status: %w", err) } - if clientKey == status.Self.PublicKey { + if clientKey == key.NodePublicFromRaw32(mem.B(status.Self.PublicKey[:])) { return nil } - if _, exists := status.Peer[clientKey]; !exists { + if _, exists := status.Peer[clientKey.AsPublic()]; !exists { return fmt.Errorf("client %v not in set of peers", clientKey) } // TODO(bradfitz): add policy for configurable bandwidth rate per client? @@ -1018,9 +1017,9 @@ func (s *Server) verifyClient(clientKey key.Public, info *clientInfo) error { } func (s *Server) sendServerKey(lw *lazyBufioWriter) error { - buf := make([]byte, 0, len(magic)+len(s.publicKey)) + buf := make([]byte, 0, len(magic)+s.publicKey.RawLen()) buf = append(buf, magic...) - buf = append(buf, s.publicKey[:]...) + buf = s.publicKey.AppendTo(buf) err := writeFrame(lw.bw(), frameServerKey, buf) lw.Flush() // redundant (no-op) flush to release bufio.Writer return err @@ -1084,21 +1083,14 @@ type serverInfo struct { TokenBucketBytesBurst int `json:",omitempty"` } -func (s *Server) sendServerInfo(bw *lazyBufioWriter, clientKey key.Public) error { - var nonce [24]byte - if _, err := crand.Read(nonce[:]); err != nil { - return err - } +func (s *Server) sendServerInfo(bw *lazyBufioWriter, clientKey key.NodePublic) error { msg, err := json.Marshal(serverInfo{Version: ProtocolVersion}) if err != nil { return err } - msgbox := box.Seal(nil, msg, &nonce, clientKey.B32(), s.privateKey.B32()) - if err := writeFrameHeader(bw.bw(), frameServerInfo, nonceLen+uint32(len(msgbox))); err != nil { - return err - } - if _, err := bw.Write(nonce[:]); err != nil { + msgbox := s.privateKey.SealTo(clientKey, msg) + if err := writeFrameHeader(bw.bw(), frameServerInfo, uint32(len(msgbox))); err != nil { return err } if _, err := bw.Write(msgbox); err != nil { @@ -1110,7 +1102,7 @@ func (s *Server) sendServerInfo(bw *lazyBufioWriter, clientKey key.Public) error // recvClientKey reads the frameClientInfo frame from the client (its // proof of identity) upon its initial connection. It should be // considered especially untrusted at this point. -func (s *Server) recvClientKey(br *bufio.Reader) (clientKey key.Public, info *clientInfo, err error) { +func (s *Server) recvClientKey(br *bufio.Reader) (clientKey key.NodePublic, info *clientInfo, err error) { fl, err := readFrameTypeHeader(br, frameClientInfo) if err != nil { return zpub, nil, err @@ -1124,21 +1116,17 @@ func (s *Server) recvClientKey(br *bufio.Reader) (clientKey key.Public, info *cl if fl > 256<<10 { return zpub, nil, errors.New("long client info") } - if _, err := io.ReadFull(br, clientKey[:]); err != nil { + if err := clientKey.ReadRawWithoutAllocating(br); err != nil { return zpub, nil, err } - var nonce [24]byte - if _, err := io.ReadFull(br, nonce[:]); err != nil { - return zpub, nil, fmt.Errorf("nonce: %v", err) - } - msgLen := int(fl - minLen) + msgLen := int(fl - keyLen) msgbox := make([]byte, msgLen) if _, err := io.ReadFull(br, msgbox); err != nil { return zpub, nil, fmt.Errorf("msgbox: %v", err) } - msg, ok := box.Open(nil, msgbox, &nonce, (*[32]byte)(&clientKey), s.privateKey.B32()) + msg, ok := s.privateKey.OpenFrom(clientKey, msgbox) if !ok { - return zpub, nil, fmt.Errorf("msgbox: cannot open len=%d with client key %x", msgLen, clientKey[:]) + return zpub, nil, fmt.Errorf("msgbox: cannot open len=%d with client key %s", msgLen, clientKey) } info = new(clientInfo) if err := json.Unmarshal(msg, info); err != nil { @@ -1147,11 +1135,11 @@ func (s *Server) recvClientKey(br *bufio.Reader) (clientKey key.Public, info *cl return clientKey, info, nil } -func (s *Server) recvPacket(br *bufio.Reader, frameLen uint32) (dstKey key.Public, contents []byte, err error) { +func (s *Server) recvPacket(br *bufio.Reader, frameLen uint32) (dstKey key.NodePublic, contents []byte, err error) { if frameLen < keyLen { return zpub, nil, errors.New("short send packet frame") } - if err := readPublicKey(br, &dstKey); err != nil { + if err := dstKey.ReadRawWithoutAllocating(br); err != nil { return zpub, nil, err } packetLen := frameLen - keyLen @@ -1173,16 +1161,16 @@ func (s *Server) recvPacket(br *bufio.Reader, frameLen uint32) (dstKey key.Publi } // zpub is the key.Public zero value. -var zpub key.Public +var zpub key.NodePublic -func (s *Server) recvForwardPacket(br *bufio.Reader, frameLen uint32) (srcKey, dstKey key.Public, contents []byte, err error) { +func (s *Server) recvForwardPacket(br *bufio.Reader, frameLen uint32) (srcKey, dstKey key.NodePublic, contents []byte, err error) { if frameLen < keyLen*2 { return zpub, zpub, nil, errors.New("short send packet frame") } - if _, err := io.ReadFull(br, srcKey[:]); err != nil { + if err := srcKey.ReadRawWithoutAllocating(br); err != nil { return zpub, zpub, nil, err } - if _, err := io.ReadFull(br, dstKey[:]); err != nil { + if err := dstKey.ReadRawWithoutAllocating(br); err != nil { return zpub, zpub, nil, err } packetLen := frameLen - keyLen*2 @@ -1206,19 +1194,19 @@ type sclient struct { connNum int64 // process-wide unique counter, incremented each Accept s *Server nc Conn - key key.Public + key key.NodePublic info clientInfo logf logger.Logf - done <-chan struct{} // closed when connection closes - remoteAddr string // usually ip:port from net.Conn.RemoteAddr().String() - remoteIPPort netaddr.IPPort // zero if remoteAddr is not ip:port. - sendQueue chan pkt // packets queued to this client; never closed - discoSendQueue chan pkt // important packets queued to this client; never closed - peerGone chan key.Public // write request that a previous sender has disconnected (not used by mesh peers) - meshUpdate chan struct{} // write request to write peerStateChange - canMesh bool // clientInfo had correct mesh token for inter-region routing - isDup syncs.AtomicBool // whether more than 1 sclient for key is connected - isDisabled syncs.AtomicBool // whether sends to this peer are disabled due to active/active dups + done <-chan struct{} // closed when connection closes + remoteAddr string // usually ip:port from net.Conn.RemoteAddr().String() + remoteIPPort netaddr.IPPort // zero if remoteAddr is not ip:port. + sendQueue chan pkt // packets queued to this client; never closed + discoSendQueue chan pkt // important packets queued to this client; never closed + peerGone chan key.NodePublic // write request that a previous sender has disconnected (not used by mesh peers) + meshUpdate chan struct{} // write request to write peerStateChange + canMesh bool // clientInfo had correct mesh token for inter-region routing + isDup syncs.AtomicBool // whether more than 1 sclient for key is connected + isDisabled syncs.AtomicBool // whether sends to this peer are disabled due to active/active dups // replaceLimiter controls how quickly two connections with // the same client key can kick each other off the server by @@ -1245,14 +1233,14 @@ type sclient struct { // peerConnState represents whether a peer is connected to the server // or not. type peerConnState struct { - peer key.Public + peer key.NodePublic present bool } // pkt is a request to write a data frame to an sclient. type pkt struct { // src is the who's the sender of the packet. - src key.Public + src key.NodePublic // enqueuedAt is when a packet was put onto a queue before it was sent, // and is used for reporting metrics on the duration of packets in the queue. @@ -1397,23 +1385,23 @@ func (c *sclient) sendKeepAlive() error { } // sendPeerGone sends a peerGone frame, without flushing. -func (c *sclient) sendPeerGone(peer key.Public) error { +func (c *sclient) sendPeerGone(peer key.NodePublic) error { c.s.peerGoneFrames.Add(1) c.setWriteDeadline() if err := writeFrameHeader(c.bw.bw(), framePeerGone, keyLen); err != nil { return err } - _, err := c.bw.Write(peer[:]) + _, err := c.bw.Write(peer.AppendTo(nil)) return err } // sendPeerPresent sends a peerPresent frame, without flushing. -func (c *sclient) sendPeerPresent(peer key.Public) error { +func (c *sclient) sendPeerPresent(peer key.NodePublic) error { c.setWriteDeadline() if err := writeFrameHeader(c.bw.bw(), framePeerPresent, keyLen); err != nil { return err } - _, err := c.bw.Write(peer[:]) + _, err := c.bw.Write(peer.AppendTo(nil)) return err } @@ -1465,7 +1453,7 @@ func (c *sclient) sendMeshUpdates() error { // DERPv2. The bytes of contents are only valid until this function // returns, do not retain slices. // It does not flush its bufio.Writer. -func (c *sclient) sendPacket(srcKey key.Public, contents []byte) (err error) { +func (c *sclient) sendPacket(srcKey key.NodePublic, contents []byte) (err error) { defer func() { // Stats update. if err != nil { @@ -1481,14 +1469,13 @@ func (c *sclient) sendPacket(srcKey key.Public, contents []byte) (err error) { withKey := !srcKey.IsZero() pktLen := len(contents) if withKey { - pktLen += len(srcKey) + pktLen += srcKey.RawLen() } if err = writeFrameHeader(c.bw.bw(), frameRecvPacket, uint32(pktLen)); err != nil { return err } if withKey { - err := writePublicKey(c.bw.bw(), &srcKey) - if err != nil { + if err := srcKey.WriteRawWithoutAllocating(c.bw.bw()); err != nil { return err } } @@ -1498,7 +1485,7 @@ func (c *sclient) sendPacket(srcKey key.Public, contents []byte) (err error) { // AddPacketForwarder registers fwd as a packet forwarder for dst. // fwd must be comparable. -func (s *Server) AddPacketForwarder(dst key.Public, fwd PacketForwarder) { +func (s *Server) AddPacketForwarder(dst key.NodePublic, fwd PacketForwarder) { s.mu.Lock() defer s.mu.Unlock() if prev, ok := s.clientsMesh[dst]; ok { @@ -1530,7 +1517,7 @@ func (s *Server) AddPacketForwarder(dst key.Public, fwd PacketForwarder) { // RemovePacketForwarder removes fwd as a packet forwarder for dst. // fwd must be comparable. -func (s *Server) RemovePacketForwarder(dst key.Public, fwd PacketForwarder) { +func (s *Server) RemovePacketForwarder(dst key.NodePublic, fwd PacketForwarder) { s.mu.Lock() defer s.mu.Unlock() v, ok := s.clientsMesh[dst] @@ -1592,7 +1579,7 @@ func (m multiForwarder) maxVal() (max uint8) { return } -func (m multiForwarder) ForwardPacket(src, dst key.Public, payload []byte) error { +func (m multiForwarder) ForwardPacket(src, dst key.NodePublic, payload []byte) error { var fwd PacketForwarder var lowest uint8 for k, v := range m { @@ -1692,37 +1679,6 @@ func (s *Server) ConsistencyCheck() error { return errors.New(strings.Join(errs, ", ")) } -// readPublicKey reads key from br. -// It is ~4x slower than io.ReadFull(br, key), -// but it prevents key from escaping and thus being allocated. -// If io.ReadFull(br, key) does not cause key to escape, use that instead. -func readPublicKey(br *bufio.Reader, key *key.Public) error { - // Do io.ReadFull(br, key), but one byte at a time, to avoid allocation. - for i := range key { - b, err := br.ReadByte() - if err != nil { - return err - } - key[i] = b - } - return nil -} - -// writePublicKey writes key to bw. -// It is ~3x slower than bw.Write(key[:]), -// but it prevents key from escaping and thus being allocated. -// If bw.Write(key[:]) does not cause key to escape, use that instead. -func writePublicKey(bw *bufio.Writer, key *key.Public) error { - // Do bw.Write(key[:]), but one byte at a time to avoid allocation. - for _, b := range key { - err := bw.WriteByte(b) - if err != nil { - return err - } - } - return nil -} - const minTimeBetweenLogs = 2 * time.Second // BytesSentRecv records the number of bytes that have been sent since the last traffic check @@ -1731,7 +1687,7 @@ type BytesSentRecv struct { Sent uint64 Recv uint64 // Key is the public key of the client which sent/received these bytes. - Key key.Public + Key key.NodePublic } // parseSSOutput parses the output from the specific call to ss in ServeDebugTraffic. diff --git a/derp/derp_test.go b/derp/derp_test.go index 7434d9b86..8555bdf5a 100644 --- a/derp/derp_test.go +++ b/derp/derp_test.go @@ -8,7 +8,6 @@ import ( "bufio" "bytes" "context" - crand "crypto/rand" "crypto/x509" "encoding/json" "errors" @@ -23,20 +22,13 @@ import ( "testing" "time" + "go4.org/mem" "golang.org/x/time/rate" "tailscale.com/net/nettest" "tailscale.com/types/key" "tailscale.com/types/logger" ) -func newPrivateKey(tb testing.TB) (k key.Private) { - tb.Helper() - if _, err := crand.Read(k[:]); err != nil { - tb.Fatal(err) - } - return -} - func TestClientInfoUnmarshal(t *testing.T) { for i, in := range []string{ `{"Version":5,"MeshKey":"abc"}`, @@ -54,15 +46,15 @@ func TestClientInfoUnmarshal(t *testing.T) { } func TestSendRecv(t *testing.T) { - serverPrivateKey := newPrivateKey(t) + serverPrivateKey := key.NewNode() s := NewServer(serverPrivateKey, t.Logf) defer s.Close() const numClients = 3 - var clientPrivateKeys []key.Private - var clientKeys []key.Public + var clientPrivateKeys []key.NodePrivate + var clientKeys []key.NodePublic for i := 0; i < numClients; i++ { - priv := newPrivateKey(t) + priv := key.NewNode() clientPrivateKeys = append(clientPrivateKeys, priv) clientKeys = append(clientKeys, priv.Public()) } @@ -225,7 +217,7 @@ func TestSendRecv(t *testing.T) { } func TestSendFreeze(t *testing.T) { - serverPrivateKey := newPrivateKey(t) + serverPrivateKey := key.NewNode() s := NewServer(serverPrivateKey, t.Logf) defer s.Close() s.WriteTimeout = 100 * time.Millisecond @@ -238,7 +230,7 @@ func TestSendFreeze(t *testing.T) { // Then cathy stops processing messsages. // That should not interfere with alice talking to bob. - newClient := func(name string, k key.Private) (c *Client, clientConn nettest.Conn) { + newClient := func(name string, k key.NodePrivate) (c *Client, clientConn nettest.Conn) { t.Helper() c1, c2 := nettest.NewConn(name, 1024) go s.Accept(c1, bufio.NewReadWriter(bufio.NewReader(c1), bufio.NewWriter(c1)), name) @@ -252,13 +244,13 @@ func TestSendFreeze(t *testing.T) { return c, c2 } - aliceKey := newPrivateKey(t) + aliceKey := key.NewNode() aliceClient, aliceConn := newClient("alice", aliceKey) - bobKey := newPrivateKey(t) + bobKey := key.NewNode() bobClient, bobConn := newClient("bob", bobKey) - cathyKey := newPrivateKey(t) + cathyKey := key.NewNode() cathyClient, cathyConn := newClient("cathy", cathyKey) var ( @@ -427,7 +419,7 @@ type testServer struct { logf logger.Logf mu sync.Mutex - pubName map[key.Public]string + pubName map[key.NodePublic]string clients map[*testClient]bool } @@ -437,14 +429,14 @@ func (ts *testServer) addTestClient(c *testClient) { ts.clients[c] = true } -func (ts *testServer) addKeyName(k key.Public, name string) { +func (ts *testServer) addKeyName(k key.NodePublic, name string) { ts.mu.Lock() defer ts.mu.Unlock() ts.pubName[k] = name ts.logf("test adding named key %q for %x", name, k) } -func (ts *testServer) keyName(k key.Public) string { +func (ts *testServer) keyName(k key.NodePublic) string { ts.mu.Lock() defer ts.mu.Unlock() if name, ok := ts.pubName[k]; ok { @@ -465,7 +457,7 @@ func (ts *testServer) close(t *testing.T) error { func newTestServer(t *testing.T) *testServer { t.Helper() logf := logger.WithPrefix(t.Logf, "derp-server: ") - s := NewServer(newPrivateKey(t), logf) + s := NewServer(key.NewNode(), logf) s.SetMeshKey("mesh-key") ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { @@ -491,7 +483,7 @@ func newTestServer(t *testing.T) *testServer { ln: ln, logf: logf, clients: map[*testClient]bool{}, - pubName: map[key.Public]string{}, + pubName: map[key.NodePublic]string{}, } } @@ -499,18 +491,18 @@ type testClient struct { name string c *Client nc net.Conn - pub key.Public + pub key.NodePublic ts *testServer closed bool } -func newTestClient(t *testing.T, ts *testServer, name string, newClient func(net.Conn, key.Private, logger.Logf) (*Client, error)) *testClient { +func newTestClient(t *testing.T, ts *testServer, name string, newClient func(net.Conn, key.NodePrivate, logger.Logf) (*Client, error)) *testClient { t.Helper() nc, err := net.Dial("tcp", ts.ln.Addr().String()) if err != nil { t.Fatal(err) } - key := newPrivateKey(t) + key := key.NewNode() ts.addKeyName(key.Public(), name) c, err := newClient(nc, key, logger.WithPrefix(t.Logf, "client-"+name+": ")) if err != nil { @@ -528,7 +520,7 @@ func newTestClient(t *testing.T, ts *testServer, name string, newClient func(net } func newRegularClient(t *testing.T, ts *testServer, name string) *testClient { - return newTestClient(t, ts, name, func(nc net.Conn, priv key.Private, logf logger.Logf) (*Client, error) { + return newTestClient(t, ts, name, func(nc net.Conn, priv key.NodePrivate, logf logger.Logf) (*Client, error) { brw := bufio.NewReadWriter(bufio.NewReader(nc), bufio.NewWriter(nc)) c, err := NewClient(priv, nc, brw, logf) if err != nil { @@ -541,7 +533,7 @@ func newRegularClient(t *testing.T, ts *testServer, name string) *testClient { } func newTestWatcher(t *testing.T, ts *testServer, name string) *testClient { - return newTestClient(t, ts, name, func(nc net.Conn, priv key.Private, logf logger.Logf) (*Client, error) { + return newTestClient(t, ts, name, func(nc net.Conn, priv key.NodePrivate, logf logger.Logf) (*Client, error) { brw := bufio.NewReadWriter(bufio.NewReader(nc), bufio.NewWriter(nc)) c, err := NewClient(priv, nc, brw, logf, MeshKey("mesh-key")) if err != nil { @@ -555,9 +547,9 @@ func newTestWatcher(t *testing.T, ts *testServer, name string) *testClient { }) } -func (tc *testClient) wantPresent(t *testing.T, peers ...key.Public) { +func (tc *testClient) wantPresent(t *testing.T, peers ...key.NodePublic) { t.Helper() - want := map[key.Public]bool{} + want := map[key.NodePublic]bool{} for _, k := range peers { want[k] = true } @@ -569,7 +561,7 @@ func (tc *testClient) wantPresent(t *testing.T, peers ...key.Public) { } switch m := m.(type) { case PeerPresentMessage: - got := key.Public(m) + got := key.NodePublic(m) if !want[got] { t.Fatalf("got peer present for %v; want present for %v", tc.ts.keyName(got), logger.ArgWriter(func(bw *bufio.Writer) { for _, pub := range peers { @@ -587,7 +579,7 @@ func (tc *testClient) wantPresent(t *testing.T, peers ...key.Public) { } } -func (tc *testClient) wantGone(t *testing.T, peer key.Public) { +func (tc *testClient) wantGone(t *testing.T, peer key.NodePublic) { t.Helper() m, err := tc.c.recvTimeout(time.Second) if err != nil { @@ -595,7 +587,7 @@ func (tc *testClient) wantGone(t *testing.T, peer key.Public) { } switch m := m.(type) { case PeerGoneMessage: - got := key.Public(m) + got := key.NodePublic(m) if peer != got { t.Errorf("got gone message for %v; want gone for %v", tc.ts.keyName(got), tc.ts.keyName(peer)) } @@ -654,21 +646,24 @@ func TestWatch(t *testing.T) { type testFwd int -func (testFwd) ForwardPacket(key.Public, key.Public, []byte) error { panic("not called in tests") } +func (testFwd) ForwardPacket(key.NodePublic, key.NodePublic, []byte) error { + panic("not called in tests") +} -func pubAll(b byte) (ret key.Public) { - for i := range ret { - ret[i] = b +func pubAll(b byte) (ret key.NodePublic) { + var bs [32]byte + for i := range bs { + bs[i] = b } - return + return key.NodePublicFromRaw32(mem.B(bs[:])) } func TestForwarderRegistration(t *testing.T) { s := &Server{ - clients: make(map[key.Public]clientSet), - clientsMesh: map[key.Public]PacketForwarder{}, + clients: make(map[key.NodePublic]clientSet), + clientsMesh: map[key.NodePublic]PacketForwarder{}, } - want := func(want map[key.Public]PacketForwarder) { + want := func(want map[key.NodePublic]PacketForwarder) { t.Helper() if got := s.clientsMesh; !reflect.DeepEqual(got, want) { t.Fatalf("mismatch\n got: %v\nwant: %v\n", got, want) @@ -687,28 +682,28 @@ func TestForwarderRegistration(t *testing.T) { s.AddPacketForwarder(u1, testFwd(1)) s.AddPacketForwarder(u2, testFwd(2)) - want(map[key.Public]PacketForwarder{ + want(map[key.NodePublic]PacketForwarder{ u1: testFwd(1), u2: testFwd(2), }) // Verify a remove of non-registered forwarder is no-op. s.RemovePacketForwarder(u2, testFwd(999)) - want(map[key.Public]PacketForwarder{ + want(map[key.NodePublic]PacketForwarder{ u1: testFwd(1), u2: testFwd(2), }) // Verify a remove of non-registered user is no-op. s.RemovePacketForwarder(u3, testFwd(1)) - want(map[key.Public]PacketForwarder{ + want(map[key.NodePublic]PacketForwarder{ u1: testFwd(1), u2: testFwd(2), }) // Actual removal. s.RemovePacketForwarder(u2, testFwd(2)) - want(map[key.Public]PacketForwarder{ + want(map[key.NodePublic]PacketForwarder{ u1: testFwd(1), }) @@ -716,7 +711,7 @@ func TestForwarderRegistration(t *testing.T) { wantCounter(&s.multiForwarderCreated, 0) s.AddPacketForwarder(u1, testFwd(100)) s.AddPacketForwarder(u1, testFwd(100)) // dup to trigger dup path - want(map[key.Public]PacketForwarder{ + want(map[key.NodePublic]PacketForwarder{ u1: multiForwarder{ testFwd(1): 1, testFwd(100): 2, @@ -726,7 +721,7 @@ func TestForwarderRegistration(t *testing.T) { // Removing a forwarder in a multi set that doesn't exist; does nothing. s.RemovePacketForwarder(u1, testFwd(55)) - want(map[key.Public]PacketForwarder{ + want(map[key.NodePublic]PacketForwarder{ u1: multiForwarder{ testFwd(1): 1, testFwd(100): 2, @@ -737,7 +732,7 @@ func TestForwarderRegistration(t *testing.T) { // from being a multiForwarder. wantCounter(&s.multiForwarderDeleted, 0) s.RemovePacketForwarder(u1, testFwd(1)) - want(map[key.Public]PacketForwarder{ + want(map[key.NodePublic]PacketForwarder{ u1: testFwd(100), }) wantCounter(&s.multiForwarderDeleted, 1) @@ -750,18 +745,18 @@ func TestForwarderRegistration(t *testing.T) { } s.clients[u1] = singleClient{u1c} s.RemovePacketForwarder(u1, testFwd(100)) - want(map[key.Public]PacketForwarder{ + want(map[key.NodePublic]PacketForwarder{ u1: nil, }) // But once that client disconnects, it should go away. s.unregisterClient(u1c) - want(map[key.Public]PacketForwarder{}) + want(map[key.NodePublic]PacketForwarder{}) // But if it already has a forwarder, it's not removed. s.AddPacketForwarder(u1, testFwd(2)) s.unregisterClient(u1c) - want(map[key.Public]PacketForwarder{ + want(map[key.NodePublic]PacketForwarder{ u1: testFwd(2), }) @@ -770,17 +765,17 @@ func TestForwarderRegistration(t *testing.T) { // from nil to the new one, not a multiForwarder. s.clients[u1] = singleClient{u1c} s.clientsMesh[u1] = nil - want(map[key.Public]PacketForwarder{ + want(map[key.NodePublic]PacketForwarder{ u1: nil, }) s.AddPacketForwarder(u1, testFwd(3)) - want(map[key.Public]PacketForwarder{ + want(map[key.NodePublic]PacketForwarder{ u1: testFwd(3), }) } func TestMetaCert(t *testing.T) { - priv := newPrivateKey(t) + priv := key.NewNode() pub := priv.Public() s := NewServer(priv, t.Logf) @@ -792,7 +787,7 @@ func TestMetaCert(t *testing.T) { if fmt.Sprint(cert.SerialNumber) != fmt.Sprint(ProtocolVersion) { t.Errorf("serial = %v; want %v", cert.SerialNumber, ProtocolVersion) } - if g, w := cert.Subject.CommonName, fmt.Sprintf("derpkey%x", pub[:]); g != w { + if g, w := cert.Subject.CommonName, fmt.Sprintf("derpkey%s", pub.UntypedHexString()); g != w { t.Errorf("CommonName = %q; want %q", g, w) } } @@ -882,10 +877,10 @@ func TestClientSendPong(t *testing.T) { } func TestServerDupClients(t *testing.T) { - serverPriv := newPrivateKey(t) + serverPriv := key.NewNode() var s *Server - clientPriv := newPrivateKey(t) + clientPriv := key.NewNode() clientPub := clientPriv.Public() var c1, c2, c3 *sclient @@ -1141,11 +1136,11 @@ func BenchmarkSendRecv(b *testing.B) { } func benchmarkSendRecvSize(b *testing.B, packetSize int) { - serverPrivateKey := newPrivateKey(b) + serverPrivateKey := key.NewNode() s := NewServer(serverPrivateKey, logger.Discard) defer s.Close() - key := newPrivateKey(b) + key := key.NewNode() clientKey := key.Public() ln, err := net.Listen("tcp", "127.0.0.1:0") @@ -1279,7 +1274,7 @@ func TestClientSendRateLimiting(t *testing.T) { c.setSendRateLimiter(ServerInfoMessage{}) pkt := make([]byte, 1000) - if err := c.send(key.Public{}, pkt); err != nil { + if err := c.send(key.NodePublic{}, pkt); err != nil { t.Fatal(err) } writes1, bytes1 := cw.Stats() @@ -1290,7 +1285,7 @@ func TestClientSendRateLimiting(t *testing.T) { // Flood should all succeed. cw.ResetStats() for i := 0; i < 1000; i++ { - if err := c.send(key.Public{}, pkt); err != nil { + if err := c.send(key.NodePublic{}, pkt); err != nil { t.Fatal(err) } } @@ -1309,7 +1304,7 @@ func TestClientSendRateLimiting(t *testing.T) { TokenBucketBytesBurst: int(bytes1 * 2), }) for i := 0; i < 1000; i++ { - if err := c.send(key.Public{}, pkt); err != nil { + if err := c.send(key.NodePublic{}, pkt); err != nil { t.Fatal(err) } } diff --git a/derp/derphttp/derphttp_client.go b/derp/derphttp/derphttp_client.go index e5ee7730c..f13ce508e 100644 --- a/derp/derphttp/derphttp_client.go +++ b/derp/derphttp/derphttp_client.go @@ -53,7 +53,7 @@ type Client struct { MeshKey string // optional; for trusted clients IsProber bool // optional; for probers to optional declare themselves as such - privateKey key.Private + privateKey key.NodePrivate logf logger.Logf dialer func(ctx context.Context, network, addr string) (net.Conn, error) @@ -71,12 +71,12 @@ type Client struct { netConn io.Closer client *derp.Client connGen int // incremented once per new connection; valid values are >0 - serverPubKey key.Public + serverPubKey key.NodePublic } // NewRegionClient returns a new DERP-over-HTTP client. It connects lazily. // To trigger a connection, use Connect. -func NewRegionClient(privateKey key.Private, logf logger.Logf, getRegion func() *tailcfg.DERPRegion) *Client { +func NewRegionClient(privateKey key.NodePrivate, logf logger.Logf, getRegion func() *tailcfg.DERPRegion) *Client { ctx, cancel := context.WithCancel(context.Background()) c := &Client{ privateKey: privateKey, @@ -96,7 +96,7 @@ func NewNetcheckClient(logf logger.Logf) *Client { // NewClient returns a new DERP-over-HTTP client. It connects lazily. // To trigger a connection, use Connect. -func NewClient(privateKey key.Private, serverURL string, logf logger.Logf) (*Client, error) { +func NewClient(privateKey key.NodePrivate, serverURL string, logf logger.Logf) (*Client, error) { u, err := url.Parse(serverURL) if err != nil { return nil, fmt.Errorf("derphttp.NewClient: %v", err) @@ -127,14 +127,14 @@ func (c *Client) Connect(ctx context.Context) error { // // It only returns a non-zero value once a connection has succeeded // from an earlier call. -func (c *Client) ServerPublicKey() key.Public { +func (c *Client) ServerPublicKey() key.NodePublic { c.mu.Lock() defer c.mu.Unlock() return c.serverPubKey } // SelfPublicKey returns our own public key. -func (c *Client) SelfPublicKey() key.Public { +func (c *Client) SelfPublicKey() key.NodePublic { return c.privateKey.Public() } @@ -315,8 +315,8 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien } }() - var httpConn net.Conn // a TCP conn or a TLS conn; what we speak HTTP to - var serverPub key.Public // or zero if unknown (if not using TLS or TLS middlebox eats it) + var httpConn net.Conn // a TCP conn or a TLS conn; what we speak HTTP to + var serverPub key.NodePublic // or zero if unknown (if not using TLS or TLS middlebox eats it) var serverProtoVersion int if c.useHTTPS() { tlsConn := c.tlsClient(tcpConn, node) @@ -687,7 +687,7 @@ func (c *Client) dialNodeUsingProxy(ctx context.Context, n *tailcfg.DERPNode, pr return proxyConn, nil } -func (c *Client) Send(dstKey key.Public, b []byte) error { +func (c *Client) Send(dstKey key.NodePublic, b []byte) error { client, _, err := c.connect(context.TODO(), "derphttp.Client.Send") if err != nil { return err @@ -698,7 +698,7 @@ func (c *Client) Send(dstKey key.Public, b []byte) error { return err } -func (c *Client) ForwardPacket(from, to key.Public, b []byte) error { +func (c *Client) ForwardPacket(from, to key.NodePublic, b []byte) error { client, _, err := c.connect(context.TODO(), "derphttp.Client.ForwardPacket") if err != nil { return err @@ -779,7 +779,7 @@ func (c *Client) WatchConnectionChanges() error { // ClosePeer asks the server to close target's TCP connection. // // Only trusted connections (using MeshKey) are allowed to use this. -func (c *Client) ClosePeer(target key.Public) error { +func (c *Client) ClosePeer(target key.NodePublic) error { client, _, err := c.connect(context.TODO(), "derphttp.Client.ClosePeer") if err != nil { return err @@ -863,15 +863,15 @@ func (c *Client) closeForReconnect(brokenClient *derp.Client) { var ErrClientClosed = errors.New("derphttp.Client closed") -func parseMetaCert(certs []*x509.Certificate) (serverPub key.Public, serverProtoVersion int) { +func parseMetaCert(certs []*x509.Certificate) (serverPub key.NodePublic, serverProtoVersion int) { for _, cert := range certs { if cn := cert.Subject.CommonName; strings.HasPrefix(cn, "derpkey") { var err error - serverPub, err = key.NewPublicFromHexMem(mem.S(strings.TrimPrefix(cn, "derpkey"))) + serverPub, err = key.ParseNodePublicUntyped(mem.S(strings.TrimPrefix(cn, "derpkey"))) if err == nil && cert.SerialNumber.BitLen() <= 8 { // supports up to version 255 return serverPub, int(cert.SerialNumber.Int64()) } } } - return key.Public{}, 0 + return key.NodePublic{}, 0 } diff --git a/derp/derphttp/derphttp_server.go b/derp/derphttp/derphttp_server.go index e12e72c65..3d2b72e88 100644 --- a/derp/derphttp/derphttp_server.go +++ b/derp/derphttp/derphttp_server.go @@ -51,9 +51,9 @@ func Handler(s *derp.Server) http.Handler { "Upgrade: DERP\r\n"+ "Connection: Upgrade\r\n"+ "Derp-Version: %v\r\n"+ - "Derp-Public-Key: %x\r\n\r\n", + "Derp-Public-Key: %s\r\n\r\n", derp.ProtocolVersion, - pubKey[:]) + pubKey.UntypedHexString()) } s.Accept(netConn, conn, netConn.RemoteAddr().String()) diff --git a/derp/derphttp/derphttp_test.go b/derp/derphttp/derphttp_test.go index d5290b8a9..40cddc4da 100644 --- a/derp/derphttp/derphttp_test.go +++ b/derp/derphttp/derphttp_test.go @@ -18,13 +18,13 @@ import ( ) func TestSendRecv(t *testing.T) { - serverPrivateKey := key.NewPrivate() + serverPrivateKey := key.NewNode() const numClients = 3 - var clientPrivateKeys []key.Private - var clientKeys []key.Public + var clientPrivateKeys []key.NodePrivate + var clientKeys []key.NodePublic for i := 0; i < numClients; i++ { - priv := key.NewPrivate() + priv := key.NewNode() clientPrivateKeys = append(clientPrivateKeys, priv) clientKeys = append(clientKeys, priv.Public()) } diff --git a/derp/derphttp/mesh_client.go b/derp/derphttp/mesh_client.go index e53e9ec5d..a77526727 100644 --- a/derp/derphttp/mesh_client.go +++ b/derp/derphttp/mesh_client.go @@ -27,7 +27,7 @@ import ( // // To force RunWatchConnectionLoop to return quickly, its ctx needs to // be closed, and c itself needs to be closed. -func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key.Public, infoLogf logger.Logf, add, remove func(key.Public)) { +func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key.NodePublic, infoLogf logger.Logf, add, remove func(key.NodePublic)) { if infoLogf == nil { infoLogf = logger.Discard } @@ -36,7 +36,7 @@ func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key const statusInterval = 10 * time.Second var ( mu sync.Mutex - present = map[key.Public]bool{} + present = map[key.NodePublic]bool{} loggedConnected = false ) clear := func() { @@ -49,7 +49,7 @@ func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key for k := range present { remove(k) } - present = map[key.Public]bool{} + present = map[key.NodePublic]bool{} } lastConnGen := 0 lastStatus := time.Now() @@ -69,7 +69,7 @@ func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key }) defer timer.Stop() - updatePeer := func(k key.Public, isPresent bool) { + updatePeer := func(k key.NodePublic, isPresent bool) { if isPresent { add(k) } else { @@ -127,9 +127,9 @@ func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key } switch m := m.(type) { case derp.PeerPresentMessage: - updatePeer(key.Public(m), true) + updatePeer(key.NodePublic(m), true) case derp.PeerGoneMessage: - updatePeer(key.Public(m), false) + updatePeer(key.NodePublic(m), false) default: continue } diff --git a/tstest/integration/integration.go b/tstest/integration/integration.go index c62b8fdf0..99552f3de 100644 --- a/tstest/integration/integration.go +++ b/tstest/integration/integration.go @@ -10,7 +10,6 @@ package integration import ( "bytes" - "crypto/rand" "crypto/tls" "encoding/json" "fmt" @@ -126,11 +125,7 @@ func exe() string { func RunDERPAndSTUN(t testing.TB, logf logger.Logf, ipAddress string) (derpMap *tailcfg.DERPMap) { t.Helper() - var serverPrivateKey key.Private - if _, err := rand.Read(serverPrivateKey[:]); err != nil { - t.Fatal(err) - } - d := derp.NewServer(serverPrivateKey, logf) + d := derp.NewServer(key.NewNode(), logf) ln, err := net.Listen("tcp", net.JoinHostPort(ipAddress, "0")) if err != nil { diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 7df3de46b..0b9241cc9 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -1340,7 +1340,7 @@ func (c *Conn) derpWriteChanOfAddr(addr netaddr.IPPort, peer key.Public) chan<- // Note that derphttp.NewRegionClient does not dial the server // so it is safe to do under the mu lock. - dc := derphttp.NewRegionClient(c.privateKey, c.logf, func() *tailcfg.DERPRegion { + dc := derphttp.NewRegionClient(key.NodePrivateFromRaw32(mem.B(c.privateKey[:])), c.logf, func() *tailcfg.DERPRegion { if c.connCtx.Err() != nil { // If we're closing, don't try to acquire the lock. // We might already be in Conn.Close and the Lock would deadlock. @@ -1539,15 +1539,15 @@ func (c *Conn) runDerpReader(ctx context.Context, derpFakeAddr netaddr.IPPort, d case derp.ReceivedPacket: pkt = m res.n = len(m.Data) - res.src = m.Source + res.src = m.Source.AsPublic() if logDerpVerbose { c.logf("magicsock: got derp-%v packet: %q", regionID, m.Data) } // If this is a new sender we hadn't seen before, remember it and // register a route for this peer. - if _, ok := peerPresent[m.Source]; !ok { - peerPresent[m.Source] = true - c.addDerpPeerRoute(m.Source, regionID, dc) + if _, ok := peerPresent[res.src]; !ok { + peerPresent[res.src] = true + c.addDerpPeerRoute(res.src, regionID, dc) } case derp.PingMessage: // Best effort reply to the ping. @@ -1601,7 +1601,7 @@ func (c *Conn) runDerpWriter(ctx context.Context, dc *derphttp.Client, ch <-chan case <-ctx.Done(): return case wr := <-ch: - err := dc.Send(wr.pubKey, wr.b) + err := dc.Send(key.NodePublicFromRaw32(mem.B(wr.pubKey[:])), wr.b) if err != nil { c.logf("magicsock: derp.Send(%v): %v", wr.addr, err) } diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 852087d9f..d51fb66e1 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -81,11 +81,7 @@ func (c *Conn) WaitReady(t testing.TB) { } func runDERPAndStun(t *testing.T, logf logger.Logf, l nettype.PacketListener, stunIP netaddr.IP) (derpMap *tailcfg.DERPMap, cleanup func()) { - var serverPrivateKey key.Private - if _, err := crand.Read(serverPrivateKey[:]); err != nil { - t.Fatal(err) - } - d := derp.NewServer(serverPrivateKey, logf) + d := derp.NewServer(key.NewNode(), logf) httpsrv := httptest.NewUnstartedServer(derphttp.Handler(d)) httpsrv.Config.ErrorLog = logger.StdLogger(logf)