mirror of https://github.com/tailscale/tailscale/
cmd/k8s-operator,k8s-operator/sessionrecording: support recording kubectl exec sessions over WebSockets (#12947)
cmd/k8s-operator,k8s-operator/sessionrecording: support recording WebSocket sessions Kubernetes currently supports two streaming protocols, SPDY and WebSockets. WebSockets are replacing SPDY, see https://github.com/kubernetes/enhancements/issues/4006. We were currently only supporting SPDY, erroring out if session was not SPDY and relying on the kube's built-in SPDY fallback. This PR: - adds support for parsing contents of 'kubectl exec' sessions streamed over WebSockets - adds logic to distinguish 'kubectl exec' requests for a SPDY/WebSockets sessions and call the relevant handler Updates tailscale/corp#19821 Signed-off-by: Irbe Krumina <irbe@tailscale.com> Co-authored-by: Tom Proctor <tomhjp@users.noreply.github.com>pull/13142/head
parent
4c2e978f1e
commit
a15ff1bade
@ -1,20 +0,0 @@
|
|||||||
// Copyright (c) Tailscale Inc & AUTHORS
|
|
||||||
// SPDX-License-Identifier: BSD-3-Clause
|
|
||||||
|
|
||||||
//go:build !plan9
|
|
||||||
|
|
||||||
// Package conn contains shared interface for the hijacked
|
|
||||||
// connection of a 'kubectl exec' session that is being recorded.
|
|
||||||
package conn
|
|
||||||
|
|
||||||
import "net"
|
|
||||||
|
|
||||||
type Conn interface {
|
|
||||||
net.Conn
|
|
||||||
// Fail can be called to set connection state to failed. By default any
|
|
||||||
// bytes left over in write buffer are forwarded to the intended
|
|
||||||
// destination when the connection is being closed except for when the
|
|
||||||
// connection state is failed- so set the state to failed when erroring
|
|
||||||
// out and failure policy is to fail closed.
|
|
||||||
Fail()
|
|
||||||
}
|
|
@ -0,0 +1,301 @@
|
|||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
//go:build !plan9
|
||||||
|
|
||||||
|
// package ws has functionality to parse 'kubectl exec' sessions streamed using
|
||||||
|
// WebSocket protocol.
|
||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"k8s.io/apimachinery/pkg/util/remotecommand"
|
||||||
|
"tailscale.com/k8s-operator/sessionrecording/tsrecorder"
|
||||||
|
"tailscale.com/sessionrecording"
|
||||||
|
"tailscale.com/util/multierr"
|
||||||
|
)
|
||||||
|
|
||||||
|
// New wraps the provided network connection and returns a connection whose reads and writes will get triggered as data is received on the hijacked connection.
|
||||||
|
// The connection must be a hijacked connection for a 'kubectl exec' session using WebSocket protocol and a *.channel.k8s.io subprotocol.
|
||||||
|
// The hijacked connection is used to transmit *.channel.k8s.io streams between Kubernetes client ('kubectl') and the destination proxy controlled by Kubernetes.
|
||||||
|
// Data read from the underlying network connection is data sent via one of the streams from the client to the container.
|
||||||
|
// Data written to the underlying connection is data sent from the container to the client.
|
||||||
|
// We parse the data and send everything for the STDOUT/STDERR streams to the configured tsrecorder as an asciinema recording with the provided header.
|
||||||
|
// https://github.com/kubernetes/enhancements/tree/master/keps/sig-api-machinery/4006-transition-spdy-to-websockets#proposal-new-remotecommand-sub-protocol-version---v5channelk8sio
|
||||||
|
func New(c net.Conn, rec *tsrecorder.Client, ch sessionrecording.CastHeader, log *zap.SugaredLogger) net.Conn {
|
||||||
|
return &conn{
|
||||||
|
Conn: c,
|
||||||
|
rec: rec,
|
||||||
|
ch: ch,
|
||||||
|
log: log,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// conn is a wrapper around net.Conn. It reads the bytestream
|
||||||
|
// for a 'kubectl exec' session, sends session recording data to the configured
|
||||||
|
// recorder and forwards the raw bytes to the original destination.
|
||||||
|
// A new conn is created per session.
|
||||||
|
// conn only knows to how to read a 'kubectl exec' session that is streamed using WebSocket protocol.
|
||||||
|
// https://www.rfc-editor.org/rfc/rfc6455
|
||||||
|
type conn struct {
|
||||||
|
net.Conn
|
||||||
|
// rec knows how to send data to a tsrecorder instance.
|
||||||
|
rec *tsrecorder.Client
|
||||||
|
// ch is the asiinema CastHeader for a session.
|
||||||
|
ch sessionrecording.CastHeader
|
||||||
|
log *zap.SugaredLogger
|
||||||
|
|
||||||
|
rmu sync.Mutex // sequences reads
|
||||||
|
// currentReadMsg contains parsed contents of a websocket binary data message that
|
||||||
|
// is currently being read from the underlying net.Conn.
|
||||||
|
currentReadMsg *message
|
||||||
|
// readBuf contains bytes for a currently parsed binary data message
|
||||||
|
// read from the underlying conn. If the message is masked, it is
|
||||||
|
// unmasked in place, so having this buffer allows us to avoid modifying
|
||||||
|
// the original byte array.
|
||||||
|
readBuf bytes.Buffer
|
||||||
|
|
||||||
|
wmu sync.Mutex // sequences writes
|
||||||
|
writeCastHeaderOnce sync.Once
|
||||||
|
closed bool // connection is closed
|
||||||
|
// writeBuf contains bytes for a currently parsed binary data message
|
||||||
|
// being written to the underlying conn. If the message is masked, it is
|
||||||
|
// unmasked in place, so having this buffer allows us to avoid modifying
|
||||||
|
// the original byte array.
|
||||||
|
writeBuf bytes.Buffer
|
||||||
|
// currentWriteMsg contains parsed contents of a websocket binary data message that
|
||||||
|
// is currently being written to the underlying net.Conn.
|
||||||
|
currentWriteMsg *message
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read reads bytes from the original connection and parses them as websocket
|
||||||
|
// message fragments.
|
||||||
|
// Bytes read from the original connection are the bytes sent from the Kubernetes client (kubectl) to the destination container via kubelet.
|
||||||
|
|
||||||
|
// If the message is for the resize stream, sets the width
|
||||||
|
// and height of the CastHeader for this connection.
|
||||||
|
// The fragment can be incomplete.
|
||||||
|
func (c *conn) Read(b []byte) (int, error) {
|
||||||
|
c.rmu.Lock()
|
||||||
|
defer c.rmu.Unlock()
|
||||||
|
n, err := c.Conn.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
// It seems that we sometimes get a wrapped io.EOF, but the
|
||||||
|
// caller checks for io.EOF with ==.
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
err = io.EOF
|
||||||
|
}
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if n == 0 {
|
||||||
|
c.log.Debug("[unexpected] Read called for 0 length bytes")
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
typ := messageType(opcode(b))
|
||||||
|
if (typ == noOpcode && c.readMsgIsIncomplete()) || c.readBufHasIncompleteFragment() { // subsequent fragment
|
||||||
|
if typ, err = c.curReadMsgType(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// A control message can not be fragmented and we are not interested in
|
||||||
|
// these messages. Just return.
|
||||||
|
if isControlMessage(typ) {
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// The only data message type that Kubernetes supports is binary message.
|
||||||
|
// If we received another message type, return and let the API server close the connection.
|
||||||
|
// https://github.com/kubernetes/client-go/blob/release-1.30/tools/remotecommand/websocket.go#L281
|
||||||
|
if typ != binaryMessage {
|
||||||
|
c.log.Infof("[unexpected] received a data message with a type that is not binary message type %v", typ)
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
readMsg := &message{typ: typ} // start a new message...
|
||||||
|
// ... or pick up an already started one if the previous fragment was not final.
|
||||||
|
if c.readMsgIsIncomplete() || c.readBufHasIncompleteFragment() {
|
||||||
|
readMsg = c.currentReadMsg
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := c.readBuf.Write(b[:n]); err != nil {
|
||||||
|
return 0, fmt.Errorf("[unexpected] error writing message contents to read buffer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := readMsg.Parse(c.readBuf.Bytes(), c.log)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("error parsing message: %v", err)
|
||||||
|
}
|
||||||
|
if !ok { // incomplete fragment
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
c.readBuf.Next(len(readMsg.raw))
|
||||||
|
|
||||||
|
if readMsg.isFinalized {
|
||||||
|
// Stream IDs for websocket streams are static.
|
||||||
|
// https://github.com/kubernetes/client-go/blob/v0.30.0-rc.1/tools/remotecommand/websocket.go#L218
|
||||||
|
if readMsg.streamID.Load() == remotecommand.StreamResize {
|
||||||
|
var err error
|
||||||
|
var msg tsrecorder.ResizeMsg
|
||||||
|
if err = json.Unmarshal(readMsg.payload, &msg); err != nil {
|
||||||
|
return 0, fmt.Errorf("error umarshalling resize message: %w", err)
|
||||||
|
}
|
||||||
|
c.ch.Width = msg.Width
|
||||||
|
c.ch.Height = msg.Height
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.currentReadMsg = readMsg
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write parses the written bytes as WebSocket message fragment. If the message
|
||||||
|
// is for stdout or stderr streams, it is written to the configured tsrecorder.
|
||||||
|
// A message fragment can be incomplete.
|
||||||
|
func (c *conn) Write(b []byte) (int, error) {
|
||||||
|
c.wmu.Lock()
|
||||||
|
defer c.wmu.Unlock()
|
||||||
|
if len(b) == 0 {
|
||||||
|
c.log.Debug("[unexpected] Write called with 0 bytes")
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
typ := messageType(opcode(b))
|
||||||
|
// If we are in process of parsing a message fragment, the received
|
||||||
|
// bytes are not structured as a message fragment and can not be used to
|
||||||
|
// determine a message fragment.
|
||||||
|
if c.writeBufHasIncompleteFragment() { // buffer contains previous incomplete fragment
|
||||||
|
var err error
|
||||||
|
if typ, err = c.curWriteMsgType(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if isControlMessage(typ) {
|
||||||
|
return c.Conn.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
writeMsg := &message{typ: typ} // start a new message...
|
||||||
|
// ... or continue the existing one if it has not been finalized.
|
||||||
|
if c.writeMsgIsIncomplete() || c.writeBufHasIncompleteFragment() {
|
||||||
|
writeMsg = c.currentWriteMsg
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := c.writeBuf.Write(b); err != nil {
|
||||||
|
c.log.Errorf("write: error writing to write buf: %v", err)
|
||||||
|
return 0, fmt.Errorf("[unexpected] error writing to internal write buffer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := writeMsg.Parse(c.writeBuf.Bytes(), c.log)
|
||||||
|
if err != nil {
|
||||||
|
c.log.Errorf("write: parsing a message errored: %v", err)
|
||||||
|
return 0, fmt.Errorf("write: error parsing message: %v", err)
|
||||||
|
}
|
||||||
|
c.currentWriteMsg = writeMsg
|
||||||
|
if !ok { // incomplete fragment
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
c.writeBuf.Next(len(writeMsg.raw)) // advance frame
|
||||||
|
|
||||||
|
if len(writeMsg.payload) != 0 && writeMsg.isFinalized {
|
||||||
|
if writeMsg.streamID.Load() == remotecommand.StreamStdOut || writeMsg.streamID.Load() == remotecommand.StreamStdErr {
|
||||||
|
var err error
|
||||||
|
c.writeCastHeaderOnce.Do(func() {
|
||||||
|
var j []byte
|
||||||
|
j, err = json.Marshal(c.ch)
|
||||||
|
if err != nil {
|
||||||
|
c.log.Errorf("error marhsalling conn: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
j = append(j, '\n')
|
||||||
|
err = c.rec.WriteCastLine(j)
|
||||||
|
if err != nil {
|
||||||
|
c.log.Errorf("received error from recorder: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("error writing CastHeader: %w", err)
|
||||||
|
}
|
||||||
|
if err := c.rec.Write(writeMsg.payload); err != nil {
|
||||||
|
return 0, fmt.Errorf("error writing message to recorder: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_, err = c.Conn.Write(c.currentWriteMsg.raw)
|
||||||
|
if err != nil {
|
||||||
|
c.log.Errorf("write: error writing to conn: %v", err)
|
||||||
|
}
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *conn) Close() error {
|
||||||
|
c.wmu.Lock()
|
||||||
|
defer c.wmu.Unlock()
|
||||||
|
if c.closed {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
c.closed = true
|
||||||
|
connCloseErr := c.Conn.Close()
|
||||||
|
recCloseErr := c.rec.Close()
|
||||||
|
return multierr.New(connCloseErr, recCloseErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeBufHasIncompleteFragment returns true if the latest data message
|
||||||
|
// fragment written to the connection was incomplete and the following write
|
||||||
|
// must be the remaining payload bytes of that fragment.
|
||||||
|
func (c *conn) writeBufHasIncompleteFragment() bool {
|
||||||
|
return c.writeBuf.Len() != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// readBufHasIncompleteFragment returns true if the latest data message
|
||||||
|
// fragment read from the connection was incomplete and the following read
|
||||||
|
// must be the remaining payload bytes of that fragment.
|
||||||
|
func (c *conn) readBufHasIncompleteFragment() bool {
|
||||||
|
return c.readBuf.Len() != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeMsgIsIncomplete returns true if the latest WebSocket message written to
|
||||||
|
// the connection was fragmented and the next data message fragment written to
|
||||||
|
// the connection must be a fragment of that message.
|
||||||
|
// https://www.rfc-editor.org/rfc/rfc6455#section-5.4
|
||||||
|
func (c *conn) writeMsgIsIncomplete() bool {
|
||||||
|
return c.currentWriteMsg != nil && !c.currentWriteMsg.isFinalized
|
||||||
|
}
|
||||||
|
|
||||||
|
// readMsgIsIncomplete returns true if the latest WebSocket message written to
|
||||||
|
// the connection was fragmented and the next data message fragment written to
|
||||||
|
// the connection must be a fragment of that message.
|
||||||
|
// https://www.rfc-editor.org/rfc/rfc6455#section-5.4
|
||||||
|
func (c *conn) readMsgIsIncomplete() bool {
|
||||||
|
return c.currentReadMsg != nil && !c.currentReadMsg.isFinalized
|
||||||
|
}
|
||||||
|
func (c *conn) curReadMsgType() (messageType, error) {
|
||||||
|
if c.currentReadMsg != nil {
|
||||||
|
return c.currentReadMsg.typ, nil
|
||||||
|
}
|
||||||
|
return 0, errors.New("[unexpected] attempted to determine type for nil message")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *conn) curWriteMsgType() (messageType, error) {
|
||||||
|
if c.currentWriteMsg != nil {
|
||||||
|
return c.currentWriteMsg.typ, nil
|
||||||
|
}
|
||||||
|
return 0, errors.New("[unexpected] attempted to determine type for nil message")
|
||||||
|
}
|
||||||
|
|
||||||
|
// opcode reads the websocket message opcode that denotes the message type.
|
||||||
|
// opcode is contained in bits [4-8] of the message.
|
||||||
|
// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
|
||||||
|
func opcode(b []byte) int {
|
||||||
|
// 0xf = 00001111; b & 00001111 zeroes out bits [0 - 3] of b
|
||||||
|
var mask byte = 0xf
|
||||||
|
return int(b[0] & mask)
|
||||||
|
}
|
@ -0,0 +1,257 @@
|
|||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
//go:build !plan9
|
||||||
|
|
||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"k8s.io/apimachinery/pkg/util/remotecommand"
|
||||||
|
"tailscale.com/k8s-operator/sessionrecording/fakes"
|
||||||
|
"tailscale.com/k8s-operator/sessionrecording/tsrecorder"
|
||||||
|
"tailscale.com/sessionrecording"
|
||||||
|
"tailscale.com/tstest"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_conn_Read(t *testing.T) {
|
||||||
|
zl, err := zap.NewDevelopment()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
// Resize stream ID + {"width": 10, "height": 20}
|
||||||
|
testResizeMsg := []byte{byte(remotecommand.StreamResize), 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74, 0x22, 0x3a, 0x32, 0x30, 0x7d}
|
||||||
|
lenResizeMsgPayload := byte(len(testResizeMsg))
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
inputs [][]byte
|
||||||
|
wantWidth int
|
||||||
|
wantHeight int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "single_read_control_message",
|
||||||
|
inputs: [][]byte{{0x88, 0x0}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single_read_resize_message",
|
||||||
|
inputs: [][]byte{append([]byte{0x82, lenResizeMsgPayload}, testResizeMsg...)},
|
||||||
|
wantWidth: 10,
|
||||||
|
wantHeight: 20,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two_reads_resize_message",
|
||||||
|
inputs: [][]byte{{0x2, 0x9, 0x4, 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22}, {0x80, 0x11, 0x4, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74, 0x22, 0x3a, 0x32, 0x30, 0x7d}},
|
||||||
|
wantWidth: 10,
|
||||||
|
wantHeight: 20,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "three_reads_resize_message_with_split_fragment",
|
||||||
|
inputs: [][]byte{{0x2, 0x9, 0x4, 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22}, {0x80, 0x11, 0x4, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74}, {0x22, 0x3a, 0x32, 0x30, 0x7d}},
|
||||||
|
wantWidth: 10,
|
||||||
|
wantHeight: 20,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tc := &fakes.TestConn{}
|
||||||
|
tc.ResetReadBuf()
|
||||||
|
c := &conn{
|
||||||
|
Conn: tc,
|
||||||
|
log: zl.Sugar(),
|
||||||
|
}
|
||||||
|
for i, input := range tt.inputs {
|
||||||
|
if err := tc.WriteReadBufBytes(input); err != nil {
|
||||||
|
t.Fatalf("writing bytes to test conn: %v", err)
|
||||||
|
}
|
||||||
|
_, err := c.Read(make([]byte, len(input)))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("[%d] conn.Read() errored %v", i, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if tt.wantHeight != 0 || tt.wantWidth != 0 {
|
||||||
|
if tt.wantWidth != c.ch.Width {
|
||||||
|
t.Errorf("wants width: %v, got %v", tt.wantWidth, c.ch.Width)
|
||||||
|
}
|
||||||
|
if tt.wantHeight != c.ch.Height {
|
||||||
|
t.Errorf("want height: %v, got %v", tt.wantHeight, c.ch.Height)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_conn_Write(t *testing.T) {
|
||||||
|
zl, err := zap.NewDevelopment()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
cl := tstest.NewClock(tstest.ClockOpts{})
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
inputs [][]byte
|
||||||
|
wantForwarded []byte
|
||||||
|
wantRecorded []byte
|
||||||
|
firstWrite bool
|
||||||
|
width int
|
||||||
|
height int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "single_write_control_frame",
|
||||||
|
inputs: [][]byte{{0x88, 0x0}},
|
||||||
|
wantForwarded: []byte{0x88, 0x0},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single_write_stdout_data_message",
|
||||||
|
inputs: [][]byte{{0x82, 0x3, 0x1, 0x7, 0x8}},
|
||||||
|
wantForwarded: []byte{0x82, 0x3, 0x1, 0x7, 0x8},
|
||||||
|
wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8}, cl),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single_write_stderr_data_message",
|
||||||
|
inputs: [][]byte{{0x82, 0x3, 0x2, 0x7, 0x8}},
|
||||||
|
wantForwarded: []byte{0x82, 0x3, 0x2, 0x7, 0x8},
|
||||||
|
wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8}, cl),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single_write_stdin_data_message",
|
||||||
|
inputs: [][]byte{{0x82, 0x3, 0x0, 0x7, 0x8}},
|
||||||
|
wantForwarded: []byte{0x82, 0x3, 0x0, 0x7, 0x8},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single_write_stdout_data_message_with_cast_header",
|
||||||
|
inputs: [][]byte{{0x82, 0x3, 0x1, 0x7, 0x8}},
|
||||||
|
wantForwarded: []byte{0x82, 0x3, 0x1, 0x7, 0x8},
|
||||||
|
wantRecorded: append(fakes.AsciinemaResizeMsg(t, 10, 20), fakes.CastLine(t, []byte{0x7, 0x8}, cl)...),
|
||||||
|
width: 10,
|
||||||
|
height: 20,
|
||||||
|
firstWrite: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two_writes_stdout_data_message",
|
||||||
|
inputs: [][]byte{{0x2, 0x3, 0x1, 0x7, 0x8}, {0x80, 0x6, 0x1, 0x1, 0x2, 0x3, 0x4, 0x5}},
|
||||||
|
wantForwarded: []byte{0x2, 0x3, 0x1, 0x7, 0x8, 0x80, 0x6, 0x1, 0x1, 0x2, 0x3, 0x4, 0x5},
|
||||||
|
wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8, 0x1, 0x2, 0x3, 0x4, 0x5}, cl),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "three_writes_stdout_data_message_with_split_fragment",
|
||||||
|
inputs: [][]byte{{0x2, 0x3, 0x1, 0x7, 0x8}, {0x80, 0x6, 0x1, 0x1, 0x2, 0x3}, {0x4, 0x5}},
|
||||||
|
wantForwarded: []byte{0x2, 0x3, 0x1, 0x7, 0x8, 0x80, 0x6, 0x1, 0x1, 0x2, 0x3, 0x4, 0x5},
|
||||||
|
wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8, 0x1, 0x2, 0x3, 0x4, 0x5}, cl),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tc := &fakes.TestConn{}
|
||||||
|
sr := &fakes.TestSessionRecorder{}
|
||||||
|
rec := tsrecorder.New(sr, cl, cl.Now(), true)
|
||||||
|
c := &conn{
|
||||||
|
Conn: tc,
|
||||||
|
log: zl.Sugar(),
|
||||||
|
ch: sessionrecording.CastHeader{
|
||||||
|
Width: tt.width,
|
||||||
|
Height: tt.height,
|
||||||
|
},
|
||||||
|
rec: rec,
|
||||||
|
}
|
||||||
|
if !tt.firstWrite {
|
||||||
|
// This test case does not intend to test that cast header gets written once.
|
||||||
|
c.writeCastHeaderOnce.Do(func() {})
|
||||||
|
}
|
||||||
|
for i, input := range tt.inputs {
|
||||||
|
_, err := c.Write(input)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("[%d] conn.Write() errored: %v", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Assert that the expected bytes have been forwarded to the original destination.
|
||||||
|
gotForwarded := tc.WriteBufBytes()
|
||||||
|
if !reflect.DeepEqual(gotForwarded, tt.wantForwarded) {
|
||||||
|
t.Errorf("expected bytes not forwarded, wants\n%x\ngot\n%x", tt.wantForwarded, gotForwarded)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert that the expected bytes have been forwarded to the session recorder.
|
||||||
|
gotRecorded := sr.Bytes()
|
||||||
|
if !reflect.DeepEqual(gotRecorded, tt.wantRecorded) {
|
||||||
|
t.Errorf("expected bytes not recorded, wants\n%b\ngot\n%b", tt.wantRecorded, gotRecorded)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test_conn_ReadRand tests reading arbitrarily generated byte slices from conn to
|
||||||
|
// test that we don't panic when parsing input from a broken or malicious
|
||||||
|
// client.
|
||||||
|
func Test_conn_ReadRand(t *testing.T) {
|
||||||
|
zl, err := zap.NewDevelopment()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error creating a test logger: %v", err)
|
||||||
|
}
|
||||||
|
for i := range 100 {
|
||||||
|
tc := &fakes.TestConn{}
|
||||||
|
tc.ResetReadBuf()
|
||||||
|
c := &conn{
|
||||||
|
Conn: tc,
|
||||||
|
log: zl.Sugar(),
|
||||||
|
}
|
||||||
|
bb := fakes.RandomBytes(t)
|
||||||
|
for j, input := range bb {
|
||||||
|
if err := tc.WriteReadBufBytes(input); err != nil {
|
||||||
|
t.Fatalf("[%d] writing bytes to test conn: %v", i, err)
|
||||||
|
}
|
||||||
|
f := func() {
|
||||||
|
c.Read(make([]byte, len(input)))
|
||||||
|
}
|
||||||
|
testPanic(t, f, fmt.Sprintf("[%d %d] Read panic parsing input of length %d first bytes: %v, current read message: %+#v", i, j, len(input), firstBytes(input), c.currentReadMsg))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test_conn_WriteRand calls conn.Write with an arbitrary input to validate that it does not
|
||||||
|
// panic.
|
||||||
|
func Test_conn_WriteRand(t *testing.T) {
|
||||||
|
zl, err := zap.NewDevelopment()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error creating a test logger: %v", err)
|
||||||
|
}
|
||||||
|
cl := tstest.NewClock(tstest.ClockOpts{})
|
||||||
|
sr := &fakes.TestSessionRecorder{}
|
||||||
|
rec := tsrecorder.New(sr, cl, cl.Now(), true)
|
||||||
|
for i := range 100 {
|
||||||
|
tc := &fakes.TestConn{}
|
||||||
|
c := &conn{
|
||||||
|
Conn: tc,
|
||||||
|
log: zl.Sugar(),
|
||||||
|
rec: rec,
|
||||||
|
}
|
||||||
|
bb := fakes.RandomBytes(t)
|
||||||
|
for j, input := range bb {
|
||||||
|
f := func() {
|
||||||
|
c.Write(input)
|
||||||
|
}
|
||||||
|
testPanic(t, f, fmt.Sprintf("[%d %d] Write: panic parsing input of length %d first bytes %b current write message %+#v", i, j, len(input), firstBytes(input), c.currentWriteMsg))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testPanic(t *testing.T, f func(), msg string) {
|
||||||
|
t.Helper()
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
t.Fatal(msg, r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
f()
|
||||||
|
}
|
||||||
|
|
||||||
|
func firstBytes(b []byte) []byte {
|
||||||
|
if len(b) < 10 {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
return b[:10]
|
||||||
|
}
|
@ -0,0 +1,267 @@
|
|||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
//go:build !plan9
|
||||||
|
|
||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
|
||||||
|
"golang.org/x/net/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
noOpcode messageType = 0 // continuation frame for fragmented messages
|
||||||
|
binaryMessage messageType = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
// messageType is the type of a websocket data or control message as defined by opcode.
|
||||||
|
// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
|
||||||
|
// Known types of control messages are close, ping and pong.
|
||||||
|
// https://www.rfc-editor.org/rfc/rfc6455#section-5.5
|
||||||
|
// The only data message type supported by Kubernetes is binary message
|
||||||
|
// https://github.com/kubernetes/client-go/blob/v0.30.0-rc.1/tools/remotecommand/websocket.go#L281
|
||||||
|
type messageType int
|
||||||
|
|
||||||
|
// message is a parsed Websocket Message.
|
||||||
|
type message struct {
|
||||||
|
// payload is the contents of the so far parsed Websocket
|
||||||
|
// data Message payload, potentially from multiple fragments written by
|
||||||
|
// multiple invocations of Parse. As per RFC 6455 We can assume that the
|
||||||
|
// fragments will always arrive in order and data messages will not be
|
||||||
|
// interleaved.
|
||||||
|
payload []byte
|
||||||
|
|
||||||
|
// isFinalized is set to true if msgPayload contains full contents of
|
||||||
|
// the message (the final fragment has been received).
|
||||||
|
isFinalized bool
|
||||||
|
|
||||||
|
// streamID is the stream to which the message belongs, i.e stdin, stout
|
||||||
|
// etc. It is one of the stream IDs defined in
|
||||||
|
// https://github.com/kubernetes/apimachinery/blob/73d12d09c5be8703587b5127416eb83dc3b7e182/pkg/util/httpstream/wsstream/doc.go#L23-L36
|
||||||
|
streamID atomic.Uint32
|
||||||
|
|
||||||
|
// typ is the type of a WebsocketMessage as defined by its opcode
|
||||||
|
// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
|
||||||
|
typ messageType
|
||||||
|
raw []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse accepts a websocket message fragment as a byte slice and parses its contents.
|
||||||
|
// It returns true if the fragment is complete, false if the fragment is incomplete.
|
||||||
|
// If the fragment is incomplete, Parse will be called again with the same fragment + more bytes when those are received.
|
||||||
|
// If the fragment is complete, it will be parsed into msg.
|
||||||
|
// A complete fragment can be:
|
||||||
|
// - a fragment that consists of a whole message
|
||||||
|
// - an initial fragment for a message for which we expect more fragments
|
||||||
|
// - a subsequent fragment for a message that we are currently parsing and whose so-far parsed contents are stored in msg.
|
||||||
|
// Parse must not be called with bytes that don't contain fragment header (so, no less than 2 bytes).
|
||||||
|
// 0 1 2 3
|
||||||
|
// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
|
||||||
|
// +-+-+-+-+-------+-+-------------+-------------------------------+
|
||||||
|
// |F|R|R|R| opcode|M| Payload len | Extended payload length |
|
||||||
|
// |I|S|S|S| (4) |A| (7) | (16/64) |
|
||||||
|
// |N|V|V|V| |S| | (if payload len==126/127) |
|
||||||
|
// | |1|2|3| |K| | |
|
||||||
|
// +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
|
||||||
|
// | Extended payload length continued, if payload len == 127 |
|
||||||
|
// + - - - - - - - - - - - - - - - +-------------------------------+
|
||||||
|
// | |Masking-key, if MASK set to 1 |
|
||||||
|
// +-------------------------------+-------------------------------+
|
||||||
|
// | Masking-key (continued) | Payload Data |
|
||||||
|
// +-------------------------------- - - - - - - - - - - - - - - - +
|
||||||
|
// : Payload Data continued ... :
|
||||||
|
// + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
|
||||||
|
// | Payload Data continued ... |
|
||||||
|
// +---------------------------------------------------------------+
|
||||||
|
// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
|
||||||
|
//
|
||||||
|
// Fragmentation rules:
|
||||||
|
// An unfragmented message consists of a single frame with the FIN
|
||||||
|
// bit set (Section 5.2) and an opcode other than 0.
|
||||||
|
// A fragmented message consists of a single frame with the FIN bit
|
||||||
|
// clear and an opcode other than 0, followed by zero or more frames
|
||||||
|
// with the FIN bit clear and the opcode set to 0, and terminated by
|
||||||
|
// a single frame with the FIN bit set and an opcode of 0.
|
||||||
|
// https://www.rfc-editor.org/rfc/rfc6455#section-5.4
|
||||||
|
func (msg *message) Parse(b []byte, log *zap.SugaredLogger) (bool, error) {
|
||||||
|
if len(b) < 2 {
|
||||||
|
return false, fmt.Errorf("[unexpected] Parse should not be called with less than 2 bytes, got %d bytes", len(b))
|
||||||
|
}
|
||||||
|
if msg.typ != binaryMessage {
|
||||||
|
return false, fmt.Errorf("[unexpected] internal error: attempted to parse a message with type %d", msg.typ)
|
||||||
|
}
|
||||||
|
isInitialFragment := len(msg.raw) == 0
|
||||||
|
|
||||||
|
msg.isFinalized = isFinalFragment(b)
|
||||||
|
|
||||||
|
maskSet := isMasked(b)
|
||||||
|
|
||||||
|
payloadLength, payloadOffset, maskOffset, err := fragmentDimensions(b, maskSet)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("error determining payload length: %w", err)
|
||||||
|
}
|
||||||
|
log.Debugf("parse: parsing a message fragment with payload length: %d payload offset: %d maskOffset: %d mask set: %t, is finalized: %t, is initial fragment: %t", payloadLength, payloadOffset, maskOffset, maskSet, msg.isFinalized, isInitialFragment)
|
||||||
|
|
||||||
|
if len(b) < int(payloadOffset+payloadLength) { // incomplete fragment
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
// TODO (irbekrm): perhaps only do this extra allocation if we know we
|
||||||
|
// will need to unmask?
|
||||||
|
msg.raw = make([]byte, int(payloadOffset)+int(payloadLength))
|
||||||
|
copy(msg.raw, b[:payloadOffset+payloadLength])
|
||||||
|
|
||||||
|
// Extract the payload.
|
||||||
|
msgPayload := b[payloadOffset : payloadOffset+payloadLength]
|
||||||
|
|
||||||
|
// Unmask the payload if needed.
|
||||||
|
// TODO (irbekrm): instead of unmasking all of the payload each time,
|
||||||
|
// determine if the payload is for a resize message early and skip
|
||||||
|
// unmasking the remaining bytes if not.
|
||||||
|
if maskSet {
|
||||||
|
m := b[maskOffset:payloadOffset]
|
||||||
|
var mask [4]byte
|
||||||
|
copy(mask[:], m)
|
||||||
|
maskBytes(mask, msgPayload)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine what stream the message is for. Stream ID of a Kubernetes
|
||||||
|
// streaming session is a 32bit integer, stored in the first byte of the
|
||||||
|
// message payload.
|
||||||
|
// https://github.com/kubernetes/apimachinery/commit/73d12d09c5be8703587b5127416eb83dc3b7e182#diff-291f96e8632d04d2d20f5fb00f6b323492670570d65434e8eac90c7a442d13bdR23-R36
|
||||||
|
if len(msgPayload) == 0 {
|
||||||
|
return false, errors.New("[unexpected] received a message fragment with no stream ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
streamID := uint32(msgPayload[0])
|
||||||
|
if !isInitialFragment && msg.streamID.Load() != streamID {
|
||||||
|
return false, fmt.Errorf("[unexpected] received message fragments with mismatched streamIDs %d and %d", msg.streamID.Load(), streamID)
|
||||||
|
}
|
||||||
|
msg.streamID.Store(streamID)
|
||||||
|
|
||||||
|
// This is normal, Kubernetes seem to send a couple data messages with
|
||||||
|
// no payloads at the start.
|
||||||
|
if len(msgPayload) < 2 {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
msgPayload = msgPayload[1:] // remove the stream ID byte
|
||||||
|
msg.payload = append(msg.payload, msgPayload...)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// maskBytes applies mask to bytes in place.
|
||||||
|
// https://www.rfc-editor.org/rfc/rfc6455#section-5.3
|
||||||
|
func maskBytes(key [4]byte, b []byte) {
|
||||||
|
for i := range b {
|
||||||
|
b[i] = b[i] ^ key[i%4]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isControlMessage returns true if the message type is one of the known control
|
||||||
|
// frame message types.
|
||||||
|
// https://www.rfc-editor.org/rfc/rfc6455#section-5.5
|
||||||
|
func isControlMessage(t messageType) bool {
|
||||||
|
const (
|
||||||
|
closeMessage messageType = 8
|
||||||
|
pingMessage messageType = 9
|
||||||
|
pongMessage messageType = 10
|
||||||
|
)
|
||||||
|
return t == closeMessage || t == pingMessage || t == pongMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
// isFinalFragment can be called with websocket message fragment and returns true if
|
||||||
|
// the fragment is the final fragment of a websocket message.
|
||||||
|
func isFinalFragment(b []byte) bool {
|
||||||
|
return extractFirstBit(b[0]) != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// isMasked can be called with a websocket message fragment and returns true if
|
||||||
|
// the payload of the message is masked. It uses the mask bit to determine if
|
||||||
|
// the payload is masked.
|
||||||
|
// https://www.rfc-editor.org/rfc/rfc6455#section-5.3
|
||||||
|
func isMasked(b []byte) bool {
|
||||||
|
return extractFirstBit(b[1]) != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractFirstBit extracts first bit of a byte by zeroing out all the other
|
||||||
|
// bits.
|
||||||
|
func extractFirstBit(b byte) byte {
|
||||||
|
return b & 0x80
|
||||||
|
}
|
||||||
|
|
||||||
|
// zeroFirstBit returns the provided byte with the first bit set to 0.
|
||||||
|
func zeroFirstBit(b byte) byte {
|
||||||
|
return b & 0x7f
|
||||||
|
}
|
||||||
|
|
||||||
|
// fragmentDimensions returns payload length as well as payload offset and mask offset.
|
||||||
|
func fragmentDimensions(b []byte, maskSet bool) (payloadLength, payloadOffset, maskOffset uint64, _ error) {
|
||||||
|
|
||||||
|
// payload length can be stored either in bits [9-15] or in bytes 2, 3
|
||||||
|
// or in bytes 2, 3, 4, 5, 6, 7.
|
||||||
|
// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
|
||||||
|
// 0 1 2 3
|
||||||
|
// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
|
||||||
|
// +-+-+-+-+-------+-+-------------+-------------------------------+
|
||||||
|
// |F|R|R|R| opcode|M| Payload len | Extended payload length |
|
||||||
|
// |I|S|S|S| (4) |A| (7) | (16/64) |
|
||||||
|
// |N|V|V|V| |S| | (if payload len==126/127) |
|
||||||
|
// | |1|2|3| |K| | |
|
||||||
|
// +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
|
||||||
|
// | Extended payload length continued, if payload len == 127 |
|
||||||
|
// + - - - - - - - - - - - - - - - +-------------------------------+
|
||||||
|
// | |Masking-key, if MASK set to 1 |
|
||||||
|
// +-------------------------------+-------------------------------+
|
||||||
|
payloadLengthIndicator := zeroFirstBit(b[1])
|
||||||
|
switch {
|
||||||
|
case payloadLengthIndicator < 126:
|
||||||
|
maskOffset = 2
|
||||||
|
payloadLength = uint64(payloadLengthIndicator)
|
||||||
|
case payloadLengthIndicator == 126:
|
||||||
|
maskOffset = 4
|
||||||
|
if len(b) < int(maskOffset) {
|
||||||
|
return 0, 0, 0, fmt.Errorf("invalid message fragment- length indicator suggests that length is stored in bytes 2:4, but message length is only %d", len(b))
|
||||||
|
}
|
||||||
|
payloadLength = uint64(binary.BigEndian.Uint16(b[2:4]))
|
||||||
|
case payloadLengthIndicator == 127:
|
||||||
|
maskOffset = 10
|
||||||
|
if len(b) < int(maskOffset) {
|
||||||
|
return 0, 0, 0, fmt.Errorf("invalid message fragment- length indicator suggests that length is stored in bytes 2:10, but message length is only %d", len(b))
|
||||||
|
}
|
||||||
|
payloadLength = binary.BigEndian.Uint64(b[2:10])
|
||||||
|
default:
|
||||||
|
return 0, 0, 0, fmt.Errorf("unexpected payload length indicator value: %v", payloadLengthIndicator)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure that a rogue or broken client doesn't cause us attempt to
|
||||||
|
// allocate a huge array by setting a high payload size.
|
||||||
|
// websocket.DefaultMaxPayloadBytes is the maximum payload size accepted
|
||||||
|
// by server side of this connection, so we can safely reject messages
|
||||||
|
// with larger payload size.
|
||||||
|
if payloadLength > websocket.DefaultMaxPayloadBytes {
|
||||||
|
return 0, 0, 0, fmt.Errorf("[unexpected]: too large payload size: %v", payloadLength)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Masking key can take up 0 or 4 bytes- we need to take that into
|
||||||
|
// account when determining payload offset.
|
||||||
|
// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
|
||||||
|
// ....
|
||||||
|
// + - - - - - - - - - - - - - - - +-------------------------------+
|
||||||
|
// | |Masking-key, if MASK set to 1 |
|
||||||
|
// +-------------------------------+-------------------------------+
|
||||||
|
// | Masking-key (continued) | Payload Data |
|
||||||
|
// + - - - - - - - - - - - - - - - +-------------------------------+
|
||||||
|
// ...
|
||||||
|
if maskSet {
|
||||||
|
payloadOffset = maskOffset + 4
|
||||||
|
} else {
|
||||||
|
payloadOffset = maskOffset
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
@ -0,0 +1,215 @@
|
|||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
//go:build !plan9
|
||||||
|
|
||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"math/rand"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"golang.org/x/net/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_msg_Parse(t *testing.T) {
|
||||||
|
zl, err := zap.NewDevelopment()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error creating a test logger: %v", err)
|
||||||
|
}
|
||||||
|
testMask := [4]byte{1, 2, 3, 4}
|
||||||
|
bs126, bs126Len := bytesSlice2ByteLen(t)
|
||||||
|
bs127, bs127Len := byteSlice8ByteLen(t)
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
b []byte
|
||||||
|
initialPayload []byte
|
||||||
|
wantPayload []byte
|
||||||
|
wantIsFinalized bool
|
||||||
|
wantStreamID uint32
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "single_fragment_stdout_stream_no_payload_no_mask",
|
||||||
|
b: []byte{0x82, 0x1, 0x1},
|
||||||
|
wantPayload: nil,
|
||||||
|
wantIsFinalized: true,
|
||||||
|
wantStreamID: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single_fragment_stderr_steam_no_payload_has_mask",
|
||||||
|
b: append([]byte{0x82, 0x81, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x2})...),
|
||||||
|
wantPayload: nil,
|
||||||
|
wantIsFinalized: true,
|
||||||
|
wantStreamID: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single_fragment_stdout_stream_no_mask_has_payload",
|
||||||
|
b: []byte{0x82, 0x3, 0x1, 0x7, 0x8},
|
||||||
|
wantPayload: []byte{0x7, 0x8},
|
||||||
|
wantIsFinalized: true,
|
||||||
|
wantStreamID: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single_fragment_stdout_stream_has_mask_has_payload",
|
||||||
|
b: append([]byte{0x82, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...),
|
||||||
|
wantPayload: []byte{0x7, 0x8},
|
||||||
|
wantIsFinalized: true,
|
||||||
|
wantStreamID: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "initial_fragment_stdout_stream_no_mask_has_payload",
|
||||||
|
b: []byte{0x2, 0x3, 0x1, 0x7, 0x8},
|
||||||
|
wantPayload: []byte{0x7, 0x8},
|
||||||
|
wantStreamID: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "initial_fragment_stdout_stream_has_mask_has_payload",
|
||||||
|
b: append([]byte{0x2, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...),
|
||||||
|
wantPayload: []byte{0x7, 0x8},
|
||||||
|
wantStreamID: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "subsequent_fragment_stdout_stream_no_mask_has_payload",
|
||||||
|
b: []byte{0x0, 0x3, 0x1, 0x7, 0x8},
|
||||||
|
initialPayload: []byte{0x1, 0x2, 0x3},
|
||||||
|
wantPayload: []byte{0x1, 0x2, 0x3, 0x7, 0x8},
|
||||||
|
wantStreamID: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "subsequent_fragment_stdout_stream_has_mask_has_payload",
|
||||||
|
b: append([]byte{0x0, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...),
|
||||||
|
initialPayload: []byte{0x1, 0x2, 0x3},
|
||||||
|
wantPayload: []byte{0x1, 0x2, 0x3, 0x7, 0x8},
|
||||||
|
wantStreamID: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "final_fragment_stdout_stream_no_mask_has_payload",
|
||||||
|
b: []byte{0x80, 0x3, 0x1, 0x7, 0x8},
|
||||||
|
initialPayload: []byte{0x1, 0x2, 0x3},
|
||||||
|
wantIsFinalized: true,
|
||||||
|
wantPayload: []byte{0x1, 0x2, 0x3, 0x7, 0x8},
|
||||||
|
wantStreamID: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "final_fragment_stdout_stream_has_mask_has_payload",
|
||||||
|
b: append([]byte{0x80, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...),
|
||||||
|
initialPayload: []byte{0x1, 0x2, 0x3},
|
||||||
|
wantIsFinalized: true,
|
||||||
|
wantPayload: []byte{0x1, 0x2, 0x3, 0x7, 0x8},
|
||||||
|
wantStreamID: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single_large_fragment_no_mask_length_hint_126",
|
||||||
|
b: append(append([]byte{0x80, 0x7e}, bs126Len...), append([]byte{0x1}, bs126...)...),
|
||||||
|
wantIsFinalized: true,
|
||||||
|
wantPayload: bs126,
|
||||||
|
wantStreamID: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single_large_fragment_no_mask_length_hint_127",
|
||||||
|
b: append(append([]byte{0x80, 0x7f}, bs127Len...), append([]byte{0x1}, bs127...)...),
|
||||||
|
wantIsFinalized: true,
|
||||||
|
wantPayload: bs127,
|
||||||
|
wantStreamID: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero_length_bytes",
|
||||||
|
b: []byte{},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
msg := &message{
|
||||||
|
typ: binaryMessage,
|
||||||
|
payload: tt.initialPayload,
|
||||||
|
}
|
||||||
|
if _, err := msg.Parse(tt.b, zl.Sugar()); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("msg.Parse() = %v, wantsErr: %t", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
if msg.isFinalized != tt.wantIsFinalized {
|
||||||
|
t.Errorf("wants message to be finalized: %t, got: %t", tt.wantIsFinalized, msg.isFinalized)
|
||||||
|
}
|
||||||
|
if msg.streamID.Load() != tt.wantStreamID {
|
||||||
|
t.Errorf("wants stream ID: %d, got: %d", tt.wantStreamID, msg.streamID.Load())
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(msg.payload, tt.wantPayload) {
|
||||||
|
t.Errorf("unexpected message payload after Parse, wants %b got %b", tt.wantPayload, msg.payload)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test_msg_Parse_Rand calls Parse with a randomly generated input to verify
|
||||||
|
// that it doesn't panic.
|
||||||
|
func Test_msg_Parse_Rand(t *testing.T) {
|
||||||
|
zl, err := zap.NewDevelopment()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error creating a test logger: %v", err)
|
||||||
|
}
|
||||||
|
r := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
|
for i := range 100 {
|
||||||
|
n := r.Intn(4096)
|
||||||
|
b := make([]byte, n)
|
||||||
|
_, err := r.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error generating random byte slice: %v", err)
|
||||||
|
}
|
||||||
|
msg := message{typ: binaryMessage}
|
||||||
|
f := func() {
|
||||||
|
msg.Parse(b, zl.Sugar())
|
||||||
|
}
|
||||||
|
testPanic(t, f, fmt.Sprintf("[%d] Parse panicked running with byte slice of length %d: %v", i, n, r))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// byteSlice2ByteLen generates a number that represents websocket message fragment length and is stored in an 8 byte slice.
|
||||||
|
// Returns the byte slice with the length as well as a slice of arbitrary bytes of the given length.
|
||||||
|
// This is used to generate test input representing websocket message with payload length hint 126.
|
||||||
|
// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
|
||||||
|
func bytesSlice2ByteLen(t *testing.T) ([]byte, []byte) {
|
||||||
|
r := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
|
var n uint16
|
||||||
|
n = uint16(rand.Intn(65535 - 1)) // space for and additional 1 byte stream ID
|
||||||
|
b := make([]byte, n)
|
||||||
|
_, err := r.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error generating random byte slice: %v ", err)
|
||||||
|
}
|
||||||
|
bb := make([]byte, 2)
|
||||||
|
binary.BigEndian.PutUint16(bb, n+1) // + stream ID
|
||||||
|
return b, bb
|
||||||
|
}
|
||||||
|
|
||||||
|
// byteSlice8ByteLen generates a number that represents websocket message fragment length and is stored in an 8 byte slice.
|
||||||
|
// Returns the byte slice with the length as well as a slice of arbitrary bytes of the given length.
|
||||||
|
// This is used to generate test input representing websocket message with payload length hint 127.
|
||||||
|
// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
|
||||||
|
func byteSlice8ByteLen(t *testing.T) ([]byte, []byte) {
|
||||||
|
nanos := time.Now().UnixNano()
|
||||||
|
t.Logf("Creating random source with seed %v", nanos)
|
||||||
|
r := rand.New(rand.NewSource(nanos))
|
||||||
|
var n uint64
|
||||||
|
n = uint64(rand.Intn(websocket.DefaultMaxPayloadBytes - 1)) // space for and additional 1 byte stream ID
|
||||||
|
t.Logf("byteSlice8ByteLen: generating message payload of length %d", n)
|
||||||
|
b := make([]byte, n)
|
||||||
|
_, err := r.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error generating random byte slice: %v ", err)
|
||||||
|
}
|
||||||
|
bb := make([]byte, 8)
|
||||||
|
binary.BigEndian.PutUint64(bb, n+1) // + stream ID
|
||||||
|
return b, bb
|
||||||
|
}
|
||||||
|
|
||||||
|
func maskedBytes(mask [4]byte, b []byte) []byte {
|
||||||
|
maskBytes(mask, b)
|
||||||
|
return b
|
||||||
|
}
|
Loading…
Reference in New Issue