derp: change the protocol framing to always include a length

Addresses one of crawshaw's TODOs.

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/86/head
Brad Fitzpatrick 5 years ago committed by Brad Fitzpatrick
parent c47f907a27
commit f029c4c82d

@ -15,16 +15,27 @@ package derp
import ( import (
"bufio" "bufio"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"time" "time"
) )
// magic is the derp magic number, sent on the wire as a uint32. // magic is the DERP magic number, sent in the frameServerKey frame
// It's "DERP" with a non-ASCII high-bit. // upon initial connection.
const magic = 0x44c55250 const magic = "DERP🔑" // 8 bytes: 0x44 45 52 50 f0 9f 94 91
// frameType is the one byte frame type header in frame headers. const (
nonceLen = 24
keyLen = 32
maxInfoLen = 1 << 20
keepAlive = 60 * time.Second
)
// frameType is the one byte frame type at the beginning of the frame
// header. The second field is a big-endian uint32 describing the
// length of the remaining frame (not including the initial 5 bytes).
type frameType byte type frameType byte
/* /*
@ -32,69 +43,110 @@ Protocol flow:
Login: Login:
* client connects * client connects
* server sends magic: [be_uint32(magic)] * server sends frameServerKey
* server sends typeServerKey frame: 1 byte typeServerKey + 32 bytes of public key * client sends frameClientInfo
* client sends: (with no frameType) * server sends frameServerInfo
- 32 bytes client public key
- 24 bytes nonce
- be_uint32 length of naclbox (capped at 256k)
- that many bytes of naclbox
* (server verifies client is authorized)
* server sends typeServerInfo frame byte + 24 byte nonce + beu32 len + naclbox(json)
Steady state: Steady state:
* server occasionally sends typeKeepAlive. (One byte only) * server occasionally sends frameKeepAlive
* client sends typeSendPacket byte + 32 byte dest pub key + beu32 packet len + packet bytes * client sends frameSendPacket
* server then sends typeRecvPacket byte + beu32 packet len + packet bytes to recipient conn * server then sends frameRecvPacket to recipient
TODO(bradfitz): require pings to be acknowledged; copy http2 PING frame w/ ping payload
*/ */
const ( const (
typeServerKey = frameType(0x01) frameServerKey = frameType(0x01) // 8B magic + 32B public key + (0+ bytes future use)
typeServerInfo = frameType(0x02) frameClientInfo = frameType(0x02) // 32B pub key + 24B nonce + naclbox(json)
typeSendPacket = frameType(0x03) frameServerInfo = frameType(0x03) // 24B nonce + naclbox(json)
typeRecvPacket = frameType(0x04) frameSendPacket = frameType(0x04) // 32B dest pub key + packet bytes
typeKeepAlive = frameType(0x05) frameRecvPacket = frameType(0x05) // packet bytes
frameKeepAlive = frameType(0x06) // no payload, no-op (to be replaced with ping/pong)
) )
func (b frameType) Write(w io.ByteWriter) error { var bin = binary.BigEndian
return w.WriteByte(byte(b))
func writeUint32(bw *bufio.Writer, v uint32) error {
var b [4]byte
bin.PutUint32(b[:], v)
_, err := bw.Write(b[:])
return err
} }
const keepAlive = 60 * time.Second func readUint32(br *bufio.Reader) (uint32, error) {
b := make([]byte, 4)
if _, err := io.ReadFull(br, b); err != nil {
return 0, err
}
return bin.Uint32(b), nil
}
var bin = binary.BigEndian func readFrameTypeHeader(br *bufio.Reader, wantType frameType) (frameLen uint32, err error) {
gotType, frameLen, err := readFrameHeader(br)
if err == nil && wantType != gotType {
err = fmt.Errorf("bad frame type 0x%X, want 0x%X", gotType, wantType)
}
return frameLen, err
}
const oneMB = 1 << 20 func readFrameHeader(br *bufio.Reader) (t frameType, frameLen uint32, err error) {
tb, err := br.ReadByte()
if err != nil {
return 0, 0, err
}
frameLen, err = readUint32(br)
if err != nil {
return 0, 0, err
}
return frameType(tb), frameLen, nil
}
func readType(r *bufio.Reader, t frameType) error { // readFrame reads a frame header and then reads its payload into
packetType, err := r.ReadByte() // b[:frameLen].
//
// If the frame header length is greater than maxSize, readFrame returns
// an error after reading the frame header.
//
// If the frame is less than maxSize but greater than len(b), len(b)
// bytes are read, err will be io.ErrShortBuffer, and frameLen and t
// will both be set. That is, callers need to explicitly handle when
// they get more data than expected.
func readFrame(br *bufio.Reader, maxSize uint32, b []byte) (t frameType, frameLen uint32, err error) {
t, frameLen, err = readFrameHeader(br)
if err != nil { if err != nil {
return err return 0, 0, err
}
if frameLen > maxSize {
return 0, 0, fmt.Errorf("frame header size %d exceeds reader limit of %d", frameLen, maxSize)
}
n, err := io.ReadFull(br, b[:frameLen])
if err != nil {
return 0, 0, err
} }
if frameType(packetType) != t { remain := frameLen - uint32(n)
return fmt.Errorf("bad packet type 0x%X, want 0x%X", packetType, t) if remain > 0 {
if _, err := io.CopyN(ioutil.Discard, br, int64(remain)); err != nil {
return 0, 0, err
} }
return nil err = io.ErrShortBuffer
}
return t, frameLen, err
} }
func putUint32(w io.Writer, v uint32) error { func writeFrameHeader(bw *bufio.Writer, t frameType, frameLen uint32) error {
var b [4]byte if err := bw.WriteByte(byte(t)); err != nil {
bin.PutUint32(b[:], v)
_, err := w.Write(b[:])
return err return err
} }
return writeUint32(bw, frameLen)
}
func readUint32(r io.Reader, maxVal uint32) (uint32, error) { // writeFrame writes a complete frame & flushes it.
b := make([]byte, 4) func writeFrame(bw *bufio.Writer, t frameType, b []byte) error {
if _, err := io.ReadFull(r, b); err != nil { if len(b) > 10<<20 {
return 0, err return errors.New("unreasonably large frame write")
}
if err := writeFrameHeader(bw, t, uint32(len(b))); err != nil {
return err
} }
val := bin.Uint32(b) if _, err := bw.Write(b); err != nil {
if val > maxVal { return err
return 0, fmt.Errorf("uint32 %d exceeds limit %d", val, maxVal)
} }
return val, nil return bw.Flush()
} }

@ -8,6 +8,7 @@ import (
"bufio" "bufio"
crand "crypto/rand" crand "crypto/rand"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -54,34 +55,40 @@ func NewClient(privateKey key.Private, nc net.Conn, brw *bufio.ReadWriter, logf
} }
func (c *Client) recvServerKey() error { func (c *Client) recvServerKey() error {
gotMagic, err := readUint32(c.br, 0xffffffff) var buf [40]byte
if err != nil { t, flen, err := readFrame(c.br, 1<<10, buf[:])
return err if err == io.ErrShortBuffer {
} // For future-proofing, allow server to send more in its greeting.
if gotMagic != magic { err = nil
return fmt.Errorf("bad magic %x, want %x", gotMagic, magic)
} }
if err := readType(c.br, typeServerKey); err != nil { if err != nil {
return err return err
} }
if _, err := io.ReadFull(c.br, c.serverKey[:]); err != nil { if flen < uint32(len(buf)) || t != frameServerKey || string(buf[:len(magic)]) != magic {
return err return errors.New("invalid server greeting")
} }
copy(c.serverKey[:], buf[len(magic):])
return nil return nil
} }
func (c *Client) recvServerInfo() (*serverInfo, error) { func (c *Client) recvServerInfo() (*serverInfo, error) {
if err := readType(c.br, typeServerInfo); err != nil { fl, err := readFrameTypeHeader(c.br, frameServerInfo)
if err != nil {
return nil, err return nil, err
} }
var nonce [24]byte const maxLength = nonceLen + maxInfoLen
if fl < nonceLen {
return nil, fmt.Errorf("short serverInfo frame")
}
if fl > maxLength {
return nil, fmt.Errorf("long serverInfo frame")
}
// TODO: add a read-nonce-and-box helper
var nonce [nonceLen]byte
if _, err := io.ReadFull(c.br, nonce[:]); err != nil { if _, err := io.ReadFull(c.br, nonce[:]); err != nil {
return nil, fmt.Errorf("nonce: %v", err) return nil, fmt.Errorf("nonce: %v", err)
} }
msgLen, err := readUint32(c.br, oneMB) msgLen := fl - nonceLen
if err != nil {
return nil, fmt.Errorf("msglen: %v", err)
}
msgbox := make([]byte, msgLen) msgbox := make([]byte, msgLen)
if _, err := io.ReadFull(c.br, msgbox); err != nil { if _, err := io.ReadFull(c.br, msgbox); err != nil {
return nil, fmt.Errorf("msgbox: %v", err) return nil, fmt.Errorf("msgbox: %v", err)
@ -98,49 +105,43 @@ func (c *Client) recvServerInfo() (*serverInfo, error) {
} }
func (c *Client) sendClientKey() error { func (c *Client) sendClientKey() error {
var nonce [24]byte var nonce [nonceLen]byte
if _, err := crand.Read(nonce[:]); err != nil { if _, err := crand.Read(nonce[:]); err != nil {
return err return err
} }
msg := []byte("{}") // no clientInfo for now msg := []byte("{}") // no clientInfo for now
msgbox := box.Seal(nil, msg, &nonce, c.serverKey.B32(), c.privateKey.B32()) msgbox := box.Seal(nil, msg, &nonce, c.serverKey.B32(), c.privateKey.B32())
if _, err := c.bw.Write(c.publicKey[:]); err != nil { buf := make([]byte, 0, nonceLen+keyLen+len(msgbox))
return err buf = append(buf, c.publicKey[:]...)
} buf = append(buf, nonce[:]...)
if _, err := c.bw.Write(nonce[:]); err != nil { buf = append(buf, msgbox...)
return err return writeFrame(c.bw, frameClientInfo, buf)
}
if err := putUint32(c.bw, uint32(len(msgbox))); err != nil {
return err
}
if _, err := c.bw.Write(msgbox); err != nil {
return err
}
return c.bw.Flush()
} }
func (c *Client) Send(dstKey key.Public, msg []byte) (err error) { // Send sends a packet to the Tailscale node identified by dstKey.
//
// It is an error if the packet is larger than 64KB.
func (c *Client) Send(dstKey key.Public, pkt []byte) error { return c.send(dstKey, pkt) }
func (c *Client) send(dstKey key.Public, pkt []byte) (ret error) {
defer func() { defer func() {
if err != nil { if ret != nil {
err = fmt.Errorf("derp.Send: %v", err) ret = fmt.Errorf("derp.Send: %v", ret)
} }
}() }()
if err := typeSendPacket.Write(c.bw); err != nil { if len(pkt) > 64<<10 {
return err return fmt.Errorf("packet too big: %d", len(pkt))
} }
if _, err := c.bw.Write(dstKey[:]); err != nil {
if err := writeFrameHeader(c.bw, frameSendPacket, uint32(len(dstKey)+len(pkt))); err != nil {
return err return err
} }
msgLen := uint32(len(msg)) if _, err := c.bw.Write(dstKey[:]); err != nil {
if int(msgLen) != len(msg) {
return fmt.Errorf("packet too big: %d", len(msg))
}
if err := putUint32(c.bw, msgLen); err != nil {
return err return err
} }
if _, err := c.bw.Write(msg); err != nil { if _, err := c.bw.Write(pkt); err != nil {
return err return err
} }
return c.bw.Flush() return c.bw.Flush()
@ -160,34 +161,21 @@ func (c *Client) Recv(b []byte) (n int, err error) {
} }
}() }()
loop:
for { for {
c.nc.SetReadDeadline(time.Now().Add(120 * time.Second)) c.nc.SetReadDeadline(time.Now().Add(120 * time.Second))
typ, err := c.br.ReadByte() t, n, err := readFrame(c.br, 1<<20, b)
if err != nil { if err != nil {
return 0, err return 0, err
} }
switch frameType(typ) { switch t {
case typeKeepAlive:
continue
case typeRecvPacket:
break loop
default: default:
return 0, fmt.Errorf("derp.Recv: unknown packet type 0x%X", typ) continue
} case frameKeepAlive:
} // TODO: eventually we'll have server->client pings that
// require ack pongs.
packetLen, err := readUint32(c.br, oneMB) continue
if err != nil { case frameRecvPacket:
return 0, err return int(n), nil
}
if int(packetLen) > len(b) {
// TODO(crawshaw): discard the packet
return 0, io.ErrShortBuffer
} }
b = b[:packetLen]
if _, err := io.ReadFull(c.br, b); err != nil {
return 0, err
} }
return int(packetLen), nil
} }

@ -4,7 +4,6 @@
package derp package derp
// TODO(crawshaw): revise protocol so unknown type packets have a predictable length for skipping.
// TODO(crawshaw): send srcKey with packets to clients? // TODO(crawshaw): send srcKey with packets to clients?
// TODO(crawshaw): with predefined serverKey in clients and HMAC on packets we could skip TLS // TODO(crawshaw): with predefined serverKey in clients and HMAC on packets we could skip TLS
@ -13,6 +12,7 @@ import (
"context" "context"
crand "crypto/rand" crand "crypto/rand"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"math/big" "math/big"
@ -107,7 +107,7 @@ func (s *Server) registerClient(c *sclient) {
s.clients[c.key] = c s.clients[c.key] = c
} }
// unregisterClient // unregisterClient removes a client from the server.
func (s *Server) unregisterClient(c *sclient) { func (s *Server) unregisterClient(c *sclient) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@ -158,9 +158,18 @@ func (s *Server) accept(nc net.Conn, brw *bufio.ReadWriter) error {
go s.sendClientKeepAlives(ctx, c) go s.sendClientKeepAlives(ctx, c)
for { for {
dstKey, contents, err := s.recvPacket(c.br) ft, fl, err := readFrameHeader(c.br)
if err != nil { if err != nil {
return fmt.Errorf("client %x: recv: %v", c.key, err) return fmt.Errorf("client %x: readFrameHeader: %v", c.key, err)
}
if ft != frameSendPacket {
// TODO: nothing else yet supported
return fmt.Errorf("client %x: unsupported frame %v", c.key, ft)
}
dstKey, contents, err := s.recvPacket(c.br, fl)
if err != nil {
return fmt.Errorf("client %x: recvPacket: %v", c.key, err)
} }
s.mu.Lock() s.mu.Lock()
@ -202,16 +211,10 @@ func (s *Server) verifyClient(clientKey key.Public, info *sclientInfo) error {
} }
func (s *Server) sendServerKey(bw *bufio.Writer) error { func (s *Server) sendServerKey(bw *bufio.Writer) error {
if err := putUint32(bw, magic); err != nil { buf := make([]byte, 0, len(magic)+len(s.publicKey))
return err buf = append(buf, magic...)
} buf = append(buf, s.publicKey[:]...)
if err := typeServerKey.Write(bw); err != nil { return writeFrame(bw, frameServerKey, buf)
return err
}
if _, err := bw.Write(s.publicKey[:]); err != nil {
return err
}
return bw.Flush()
} }
func (s *Server) sendServerInfo(bw *bufio.Writer, clientKey key.Public) error { func (s *Server) sendServerInfo(bw *bufio.Writer, clientKey key.Public) error {
@ -221,25 +224,35 @@ func (s *Server) sendServerInfo(bw *bufio.Writer, clientKey key.Public) error {
} }
msg := []byte("{}") // no serverInfo for now msg := []byte("{}") // no serverInfo for now
msgbox := box.Seal(nil, msg, &nonce, clientKey.B32(), s.privateKey.B32()) msgbox := box.Seal(nil, msg, &nonce, clientKey.B32(), s.privateKey.B32())
if err := writeFrameHeader(bw, frameServerInfo, nonceLen+uint32(len(msgbox))); err != nil {
if err := typeServerInfo.Write(bw); err != nil {
return err return err
} }
if _, err := bw.Write(nonce[:]); err != nil { if _, err := bw.Write(nonce[:]); err != nil {
return err return err
} }
if err := putUint32(bw, uint32(len(msgbox))); err != nil {
return err
}
if _, err := bw.Write(msgbox); err != nil { if _, err := bw.Write(msgbox); err != nil {
return err return err
} }
return bw.Flush() return bw.Flush()
} }
// recvClientKey reads the client's hello (its proof of identity) upon its initial connection. // recvClientKey reads the frameClientInfo frame from the client (its
// It should be considered especially untrusted at this point. // proof of identity) upon its initial connection. It should be
// considered especially untrusted at this point.
func (s *Server) recvClientKey(br *bufio.Reader) (clientKey key.Public, info *sclientInfo, err error) { func (s *Server) recvClientKey(br *bufio.Reader) (clientKey key.Public, info *sclientInfo, err error) {
fl, err := readFrameTypeHeader(br, frameClientInfo)
if err != nil {
return key.Public{}, nil, err
}
const minLen = keyLen + nonceLen
if fl < minLen {
return key.Public{}, nil, errors.New("short client info")
}
// We don't trust the client at all yet, so limit its input size to limit
// things like JSON resource exhausting (http://github.com/golang/go/issues/31789).
if fl > 256<<10 {
return key.Public{}, nil, errors.New("long client info")
}
if _, err := io.ReadFull(br, clientKey[:]); err != nil { if _, err := io.ReadFull(br, clientKey[:]); err != nil {
return key.Public{}, nil, err return key.Public{}, nil, err
} }
@ -247,12 +260,7 @@ func (s *Server) recvClientKey(br *bufio.Reader) (clientKey key.Public, info *sc
if _, err := io.ReadFull(br, nonce[:]); err != nil { if _, err := io.ReadFull(br, nonce[:]); err != nil {
return key.Public{}, nil, fmt.Errorf("nonce: %v", err) return key.Public{}, nil, fmt.Errorf("nonce: %v", err)
} }
// We don't trust the client at all yet, so limit its input size to limit msgLen := int(fl - minLen)
// things like JSON resource exhausting (http://github.com/golang/go/issues/31789).
msgLen, err := readUint32(br, 256<<10)
if err != nil {
return key.Public{}, nil, fmt.Errorf("msglen: %v", err)
}
msgbox := make([]byte, msgLen) msgbox := make([]byte, msgLen)
if _, err := io.ReadFull(br, msgbox); err != nil { if _, err := io.ReadFull(br, msgbox); err != nil {
return key.Public{}, nil, fmt.Errorf("msgbox: %v", err) return key.Public{}, nil, fmt.Errorf("msgbox: %v", err)
@ -269,10 +277,7 @@ func (s *Server) recvClientKey(br *bufio.Reader) (clientKey key.Public, info *sc
} }
func (s *Server) sendPacket(bw *bufio.Writer, srcKey key.Public, contents []byte) error { func (s *Server) sendPacket(bw *bufio.Writer, srcKey key.Public, contents []byte) error {
if err := typeRecvPacket.Write(bw); err != nil { if err := writeFrameHeader(bw, frameRecvPacket, uint32(len(contents))); err != nil {
return err
}
if err := putUint32(bw, uint32(len(contents))); err != nil {
return err return err
} }
if _, err := bw.Write(contents); err != nil { if _, err := bw.Write(contents); err != nil {
@ -281,17 +286,14 @@ func (s *Server) sendPacket(bw *bufio.Writer, srcKey key.Public, contents []byte
return bw.Flush() return bw.Flush()
} }
func (s *Server) recvPacket(br *bufio.Reader) (dstKey key.Public, contents []byte, err error) { func (s *Server) recvPacket(br *bufio.Reader, frameLen uint32) (dstKey key.Public, contents []byte, err error) {
if err := readType(br, typeSendPacket); err != nil { if frameLen < keyLen {
return key.Public{}, nil, err return key.Public{}, nil, errors.New("short send packet frame")
} }
if _, err := io.ReadFull(br, dstKey[:]); err != nil { if _, err := io.ReadFull(br, dstKey[:]); err != nil {
return key.Public{}, nil, err return key.Public{}, nil, err
} }
packetLen, err := readUint32(br, oneMB) packetLen := frameLen - keyLen
if err != nil {
return key.Public{}, nil, err
}
contents = make([]byte, packetLen) contents = make([]byte, packetLen)
if _, err := io.ReadFull(br, contents); err != nil { if _, err := io.ReadFull(br, contents); err != nil {
return key.Public{}, nil, err return key.Public{}, nil, err
@ -335,7 +337,7 @@ func (c *sclient) keepAliveLoop(ctx context.Context) error {
c.keepAliveTimer.Reset(keepAlive + jitter) c.keepAliveTimer.Reset(keepAlive + jitter)
case <-c.keepAliveTimer.C: case <-c.keepAliveTimer.C:
c.mu.Lock() c.mu.Lock()
err := typeKeepAlive.Write(c.bw) err := writeFrame(c.bw, frameKeepAlive, nil)
if err == nil { if err == nil {
err = c.bw.Flush() err = c.bw.Flush()
} }

Loading…
Cancel
Save