diff --git a/derp/derp_server.go b/derp/derp_server.go index 369cd69d3..7ff6a9fa0 100644 --- a/derp/derp_server.go +++ b/derp/derp_server.go @@ -86,7 +86,7 @@ func (s *Server) Accept(netConn net.Conn, conn *bufio.ReadWriter) { func (s *Server) accept(netConn net.Conn, conn *bufio.ReadWriter) error { netConn.SetDeadline(time.Now().Add(10 * time.Second)) - if err := s.sendServerKey(conn); err != nil { + if err := s.sendServerKey(conn.Writer); err != nil { return fmt.Errorf("send server key: %v", err) } netConn.SetDeadline(time.Now().Add(10 * time.Second)) @@ -150,7 +150,7 @@ func (s *Server) accept(netConn net.Conn, conn *bufio.ReadWriter) error { } for { - dstKey, contents, err := s.recvPacket(c.conn) + dstKey, contents, err := s.recvPacket(c.conn.Reader) if err != nil { return fmt.Errorf("client %x: recv: %v", c.key, err) } @@ -187,17 +187,17 @@ func (s *Server) verifyClient(clientKey key.Public, info *clientInfo) error { return nil } -func (s *Server) sendServerKey(conn *bufio.ReadWriter) error { - if err := putUint32(conn, magic); err != nil { +func (s *Server) sendServerKey(bw *bufio.Writer) error { + if err := putUint32(bw, magic); err != nil { return err } - if err := typeServerKey.Write(conn); err != nil { + if err := typeServerKey.Write(bw); err != nil { return err } - if _, err := conn.Write(s.publicKey[:]); err != nil { + if _, err := bw.Write(s.publicKey[:]); err != nil { return err } - return conn.Flush() + return bw.Flush() } func (s *Server) sendServerInfo(conn *bufio.ReadWriter, clientKey key.Public) error { @@ -223,20 +223,24 @@ func (s *Server) sendServerInfo(conn *bufio.ReadWriter, clientKey key.Public) er return conn.Flush() } -func (s *Server) recvClientKey(conn *bufio.ReadWriter) (clientKey key.Public, info *clientInfo, err error) { - if _, err := io.ReadFull(conn, clientKey[:]); err != nil { +// recvClientKey reads the client's hello (its proof of identity) upon its initial connection. +// It should be considered especially untrusted at this point. +func (s *Server) recvClientKey(br *bufio.ReadWriter) (clientKey key.Public, info *clientInfo, err error) { + if _, err := io.ReadFull(br, clientKey[:]); err != nil { return key.Public{}, nil, err } var nonce [24]byte - if _, err := io.ReadFull(conn, nonce[:]); err != nil { + if _, err := io.ReadFull(br, nonce[:]); err != nil { return key.Public{}, nil, fmt.Errorf("nonce: %v", err) } - msgLen, err := readUint32(conn, oneMB) + // We don't trust the client at all yet, so limit its input size to limit + // things like JSON resource exhausting (http://github.com/golang/go/issues/31789). + msgLen, err := readUint32(br, 256<<10) if err != nil { return key.Public{}, nil, fmt.Errorf("msglen: %v", err) } msgbox := make([]byte, msgLen) - if _, err := io.ReadFull(conn, msgbox); err != nil { + if _, err := io.ReadFull(br, msgbox); err != nil { return key.Public{}, nil, fmt.Errorf("msgbox: %v", err) } msg, ok := box.Open(nil, msgbox, &nonce, (*[32]byte)(&clientKey), s.privateKey.B32()) @@ -263,19 +267,19 @@ func (s *Server) sendPacket(bw *bufio.Writer, srcKey key.Public, contents []byte return bw.Flush() } -func (s *Server) recvPacket(conn *bufio.ReadWriter) (dstKey key.Public, contents []byte, err error) { - if err := readType(conn.Reader, typeSendPacket); err != nil { +func (s *Server) recvPacket(br *bufio.Reader) (dstKey key.Public, contents []byte, err error) { + if err := readType(br, typeSendPacket); err != nil { return key.Public{}, nil, err } - if _, err := io.ReadFull(conn, dstKey[:]); err != nil { + if _, err := io.ReadFull(br, dstKey[:]); err != nil { return key.Public{}, nil, err } - packetLen, err := readUint32(conn.Reader, oneMB) + packetLen, err := readUint32(br, oneMB) if err != nil { return key.Public{}, nil, err } contents = make([]byte, packetLen) - if _, err := io.ReadFull(conn, contents); err != nil { + if _, err := io.ReadFull(br, contents); err != nil { return key.Public{}, nil, err } return dstKey, contents, nil diff --git a/derp/derphttp/derphttp_test.go b/derp/derphttp/derphttp_test.go index 4354ce2d8..73885681d 100644 --- a/derp/derphttp/derphttp_test.go +++ b/derp/derphttp/derphttp_test.go @@ -5,6 +5,7 @@ package derphttp import ( + "context" crand "crypto/rand" "crypto/tls" "net" @@ -77,6 +78,9 @@ func TestSendRecv(t *testing.T) { if err != nil { t.Fatalf("client %d: %v", i, err) } + if err := c.Connect(context.Background()); err != nil { + t.Fatalf("client %d Connect: %v", i, err) + } clients = append(clients, c) recvChs = append(recvChs, make(chan []byte))