control/controlclient, tailcfg: add support for EarlyNoise payload before http/2

Not yet used, but skipped over, parsed, and tested.

Updates #5972

Change-Id: Icd00196959ce266ae16a6c9244bd5e458e2c2947
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/6168/head
Brad Fitzpatrick 2 years ago committed by Brad Fitzpatrick
parent a7f7e79245
commit 988c1f0ac7

@ -7,7 +7,10 @@ package controlclient
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/binary"
"encoding/json" "encoding/json"
"errors"
"io"
"math" "math"
"net/http" "net/http"
"net/url" "net/url"
@ -34,6 +37,77 @@ type noiseConn struct {
id int id int
pool *noiseClient pool *noiseClient
h2cc *http2.ClientConn h2cc *http2.ClientConn
readHeaderOnce sync.Once // guards init of reader field
reader io.Reader // (effectively Conn.Reader after header)
earlyPayloadReady chan struct{} // closed after earlyPayload is set (including set to nil)
earlyPayload *tailcfg.EarlyNoise
}
func (c *noiseConn) RoundTrip(r *http.Request) (*http.Response, error) {
return c.h2cc.RoundTrip(r)
}
// The first 9 bytes from the server to client over Noise are either an HTTP/2
// settings frame (a normal HTTP/2 setup) or, as we added later, an "early payload"
// header that's also 9 bytes long: 5 bytes (earlyPayloadMagic) followed by 4 bytes
// of length. Then that many bytes of JSON-encoded tailcfg.EarlyNoise.
// The early payload is optional. Some servers may not send it.
const (
hdrLen = 9 // http2 frame header size; also size of our early payload size header
earlyPayloadMagic = "\xff\xff\xffTS"
)
// returnErrReader is an io.Reader that always returns an error.
type returnErrReader struct {
err error // the error to return
}
func (r returnErrReader) Read([]byte) (int, error) { return 0, r.err }
// Read is basically the same as controlbase.Conn.Read, but it first reads the
// "early payload" header from the server which may or may not be present,
// depending on the server.
func (c *noiseConn) Read(p []byte) (n int, err error) {
c.readHeaderOnce.Do(c.readHeader)
return c.reader.Read(p)
}
// readHeader reads the optional "early payload" from the server that arrives
// after the Noise handshake but before the HTTP/2 session begins.
//
// readHeader is responsible for reading the header (if present), initializing
// c.earlyPayload, closing c.earlyPayloadReady, and initializing c.reader for
// future reads.
func (c *noiseConn) readHeader() {
var hdr [hdrLen]byte
if _, err := io.ReadFull(c.Conn, hdr[:]); err != nil {
c.reader = returnErrReader{err}
return
}
if string(hdr[:len(earlyPayloadMagic)]) != earlyPayloadMagic {
// No early payload. We have to return the 9 bytes read we already
// consumed.
close(c.earlyPayloadReady)
c.reader = io.MultiReader(bytes.NewReader(hdr[:]), c.Conn)
return
}
epLen := binary.BigEndian.Uint32(hdr[len(earlyPayloadMagic):])
if epLen > 10<<20 {
c.reader = returnErrReader{errors.New("invalid early payload length")}
return
}
payBuf := make([]byte, epLen)
if _, err := io.ReadFull(c.Conn, payBuf); err != nil {
c.reader = returnErrReader{err}
return
}
if err := json.Unmarshal(payBuf, &c.earlyPayload); err != nil {
c.reader = returnErrReader{err}
return
}
close(c.earlyPayloadReady)
c.reader = c.Conn
} }
func (c *noiseConn) Close() error { func (c *noiseConn) Close() error {
@ -88,7 +162,7 @@ type noiseClient struct {
// serverURL is of the form https://<host>:<port> (no trailing slash). // serverURL is of the form https://<host>:<port> (no trailing slash).
// //
// dialPlan may be nil // dialPlan may be nil
func newNoiseClient(priKey key.MachinePrivate, serverPubKey key.MachinePublic, serverURL string, dialer *tsdial.Dialer, dialPlan func() *tailcfg.ControlDialPlan) (*noiseClient, error) { func newNoiseClient(privKey key.MachinePrivate, serverPubKey key.MachinePublic, serverURL string, dialer *tsdial.Dialer, dialPlan func() *tailcfg.ControlDialPlan) (*noiseClient, error) {
u, err := url.Parse(serverURL) u, err := url.Parse(serverURL)
if err != nil { if err != nil {
return nil, err return nil, err
@ -111,7 +185,7 @@ func newNoiseClient(priKey key.MachinePrivate, serverPubKey key.MachinePublic, s
} }
np := &noiseClient{ np := &noiseClient{
serverPubKey: serverPubKey, serverPubKey: serverPubKey,
privKey: priKey, privKey: privKey,
host: u.Hostname(), host: u.Hostname(),
httpPort: httpPort, httpPort: httpPort,
httpsPort: httpsPort, httpsPort: httpsPort,
@ -157,7 +231,7 @@ func (nc *noiseClient) RoundTrip(req *http.Request) (*http.Response, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return conn.h2cc.RoundTrip(req) return conn.RoundTrip(req)
} }
// connClosed removes the connection with the provided ID from the pool // connClosed removes the connection with the provided ID from the pool
@ -259,14 +333,12 @@ func (nc *noiseClient) dial() (*noiseConn, error) {
} }
ncc := &noiseConn{ ncc := &noiseConn{
Conn: clientConn.Conn, Conn: clientConn.Conn,
id: connID, id: connID,
pool: nc, pool: nc,
earlyPayloadReady: make(chan struct{}),
} }
// TODO(bradfitz): wrap clientConn in a type that sniffs the leading bytes
// from the server to see if it has early post-Noise, pre-H2 data for us.
h2cc, err := nc.h2t.NewClientConn(ncc) h2cc, err := nc.h2t.NewClientConn(ncc)
if err != nil { if err != nil {
return nil, err return nil, err

@ -13,6 +13,7 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"time"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"tailscale.com/control/controlhttp" "tailscale.com/control/controlhttp"
@ -38,15 +39,32 @@ func TestNoiseVersion(t *testing.T) {
} }
} }
type noiseClientTest struct {
sendEarlyPayload bool
}
func TestNoiseClientHTTP2Upgrade(t *testing.T) { func TestNoiseClientHTTP2Upgrade(t *testing.T) {
noiseClientTest{}.run(t)
}
func TestNoiseClientHTTP2Upgrade_earlyPayload(t *testing.T) {
noiseClientTest{
sendEarlyPayload: true,
}.run(t)
}
func (tt noiseClientTest) run(t *testing.T) {
serverPrivate := key.NewMachine() serverPrivate := key.NewMachine()
clientPrivate := key.NewMachine() clientPrivate := key.NewMachine()
chalPrivate := key.NewChallenge()
const msg = "Hello, client" const msg = "Hello, client"
h2 := &http2.Server{} h2 := &http2.Server{}
hs := httptest.NewServer(&Upgrader{ hs := httptest.NewServer(&Upgrader{
h2srv: h2, h2srv: h2,
noiseKeyPriv: serverPrivate, noiseKeyPriv: serverPrivate,
sendEarlyPayload: tt.sendEarlyPayload,
challenge: chalPrivate,
httpBaseConfig: &http.Server{ httpBaseConfig: &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain") w.Header().Set("Content-Type", "text/plain")
@ -61,19 +79,56 @@ func TestNoiseClientHTTP2Upgrade(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
res, err := nc.post(context.Background(), "/", nil)
// Get a conn and verify it read its early payload before the http/2
// handshake.
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
c, err := nc.getConn(ctx)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer res.Body.Close() select {
all, err := io.ReadAll(res.Body) case <-c.earlyPayloadReady:
gotNonNil := c.earlyPayload != nil
if gotNonNil != tt.sendEarlyPayload {
t.Errorf("sendEarlyPayload = %v but got earlyPayload = %T", tt.sendEarlyPayload, c.earlyPayload)
}
if c.earlyPayload != nil {
if c.earlyPayload.NodeKeyChallenge != chalPrivate.Public() {
t.Errorf("earlyPayload.NodeKeyChallenge = %v; want %v", c.earlyPayload.NodeKeyChallenge, chalPrivate.Public())
}
}
case <-ctx.Done():
t.Fatal("timed out waiting for didReadHeaderCh")
}
checkRes := func(t *testing.T, res *http.Response) {
t.Helper()
defer res.Body.Close()
all, err := io.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if string(all) != msg {
t.Errorf("got response %q; want %q", all, msg)
}
}
// And verify we can do HTTP/2 against that conn.
res, err := (&http.Client{Transport: c}).Get("https://unused.example/")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if string(all) != msg { checkRes(t, res)
t.Errorf("got response %q; want %q", all, msg)
}
// And try using the high-level nc.post API as well.
res, err = nc.post(context.Background(), "/", nil)
if err != nil {
t.Fatal(err)
}
checkRes(t, res)
} }
// Upgrader is an http.Handler that hijacks and upgrades POST-with-Upgrade // Upgrader is an http.Handler that hijacks and upgrades POST-with-Upgrade
@ -91,6 +146,7 @@ type Upgrader struct {
logf logger.Logf logf logger.Logf
noiseKeyPriv key.MachinePrivate noiseKeyPriv key.MachinePrivate
challenge key.ChallengePrivate
sendEarlyPayload bool sendEarlyPayload bool
} }
@ -109,21 +165,21 @@ func (up *Upgrader) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
chalPub := key.NewChallenge()
earlyWriteFn := func(protocolVersion int, w io.Writer) error { earlyWriteFn := func(protocolVersion int, w io.Writer) error {
if !up.sendEarlyPayload { if !up.sendEarlyPayload {
return nil return nil
} }
earlyJSON, err := json.Marshal(struct { earlyJSON, err := json.Marshal(&tailcfg.EarlyNoise{
NodeKeyOwnershipChallenge string NodeKeyChallenge: up.challenge.Public(),
}{chalPub.Public().String()}) })
if err != nil { if err != nil {
return err return err
} }
// 5 bytes that won't be mistaken for an HTTP/2 frame: // 5 bytes that won't be mistaken for an HTTP/2 frame:
// https://httpwg.org/specs/rfc7540.html#rfc.section.4.1 (Especially not // https://httpwg.org/specs/rfc7540.html#rfc.section.4.1 (Especially not
// an HTTP/2 settings frame, which isn't of type 'T') // an HTTP/2 settings frame, which isn't of type 'T')
var notH2Frame = [5]byte{0xff, 0xff, 0xff, 'T', 'S'} var notH2Frame [5]byte
copy(notH2Frame[:], earlyPayloadMagic)
var lenBuf [4]byte var lenBuf [4]byte
binary.BigEndian.PutUint32(lenBuf[:], uint32(len(earlyJSON))) binary.BigEndian.PutUint32(lenBuf[:], uint32(len(earlyJSON)))
// These writes are all buffered by caller, so fine to do them // These writes are all buffered by caller, so fine to do them

@ -1946,3 +1946,15 @@ type PeerChange struct {
// //
// Mnemonic: 3.3.40 are numbers above the keys D, E, R, P. // Mnemonic: 3.3.40 are numbers above the keys D, E, R, P.
const DerpMagicIP = "127.3.3.40" const DerpMagicIP = "127.3.3.40"
// EarlyNoise is the early payload that's sent over Noise but before the HTTP/2
// handshake when connecting to the coordination server.
//
// This exists to let the server push some early info to client for that
// stateful HTTP/2+Noise connection without incurring an extra round trip. (This
// would've used HTTP/2 server push, had Go's client-side APIs been available)
type EarlyNoise struct {
// NodeKeyChallenge is a random per-connection public key to be used by
// the client to prove possession of a wireguard private key.
NodeKeyChallenge key.ChallengePublic `json:"nodeKeyChallenge"`
}

Loading…
Cancel
Save