control/controlclient: export NoiseClient

This allows reusing the NoiseClient in other repos without having to reimplement the earlyPayload logic.

Signed-off-by: Maisem Ali <maisem@tailscale.com>
pull/6188/head
Maisem Ali 2 years ago committed by Maisem Ali
parent d57cba8655
commit a413fa4f85

@ -84,8 +84,8 @@ type Direct struct {
serverKey key.MachinePublic // original ("legacy") nacl crypto_box-based public key serverKey key.MachinePublic // original ("legacy") nacl crypto_box-based public key
serverNoiseKey key.MachinePublic serverNoiseKey key.MachinePublic
sfGroup singleflight.Group[struct{}, *noiseClient] // protects noiseClient creation. sfGroup singleflight.Group[struct{}, *NoiseClient] // protects noiseClient creation.
noiseClient *noiseClient noiseClient *NoiseClient
persist persist.Persist persist persist.Persist
authKey string authKey string
@ -262,7 +262,7 @@ func NewDirect(opts Options) (*Direct, error) {
} }
} }
if opts.NoiseTestClient != nil { if opts.NoiseTestClient != nil {
c.noiseClient = &noiseClient{ c.noiseClient = &NoiseClient{
Client: opts.NoiseTestClient, Client: opts.NoiseTestClient,
} }
c.serverNoiseKey = key.NewMachine().Public() // prevent early error before hitting test client c.serverNoiseKey = key.NewMachine().Public() // prevent early error before hitting test client
@ -1470,7 +1470,7 @@ func sleepAsRequested(ctx context.Context, logf logger.Logf, timeoutReset chan<-
} }
// getNoiseClient returns the noise client, creating one if one doesn't exist. // getNoiseClient returns the noise client, creating one if one doesn't exist.
func (c *Direct) getNoiseClient() (*noiseClient, error) { func (c *Direct) getNoiseClient() (*NoiseClient, error) {
c.mu.Lock() c.mu.Lock()
serverNoiseKey := c.serverNoiseKey serverNoiseKey := c.serverNoiseKey
nc := c.noiseClient nc := c.noiseClient
@ -1485,13 +1485,13 @@ func (c *Direct) getNoiseClient() (*noiseClient, error) {
if c.dialPlan != nil { if c.dialPlan != nil {
dp = c.dialPlan.Load dp = c.dialPlan.Load
} }
nc, err, _ := c.sfGroup.Do(struct{}{}, func() (*noiseClient, error) { nc, err, _ := c.sfGroup.Do(struct{}{}, func() (*NoiseClient, error) {
k, err := c.getMachinePrivKey() k, err := c.getMachinePrivKey()
if err != nil { if err != nil {
return nil, err return nil, err
} }
c.logf("creating new noise client") c.logf("creating new noise client")
nc, err := newNoiseClient(k, serverNoiseKey, c.serverURL, c.dialer, dp) nc, err := NewNoiseClient(k, serverNoiseKey, c.serverURL, c.dialer, dp)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1618,20 +1618,7 @@ func (c *Direct) GetSingleUseNoiseRoundTripper(ctx context.Context) (http.RoundT
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
for tries := 0; tries < 3; tries++ { return nc.GetSingleUseRoundTripper(ctx)
conn, err := nc.getConn(ctx)
if err != nil {
return nil, nil, err
}
earlyPayloadMaybeNil, err := conn.getEarlyPayload(ctx)
if err != nil {
return nil, nil, err
}
if conn.h2cc.ReserveNewRequest() {
return conn, earlyPayloadMaybeNil, nil
}
}
return nil, nil, errors.New("[unexpected] failed to reserve a request on a connection")
} }
// doPingerPing sends a Ping to pr.IP using pinger, and sends an http request back to // doPingerPing sends a Ping to pr.IP using pinger, and sends an http request back to

@ -35,7 +35,7 @@ import (
type noiseConn struct { type noiseConn struct {
*controlbase.Conn *controlbase.Conn
id int id int
pool *noiseClient pool *NoiseClient
h2cc *http2.ClientConn h2cc *http2.ClientConn
readHeaderOnce sync.Once // guards init of reader field readHeaderOnce sync.Once // guards init of reader field
@ -135,9 +135,9 @@ func (c *noiseConn) Close() error {
return nil return nil
} }
// noiseClient provides a http.Client to connect to tailcontrol over // NoiseClient provides a http.Client to connect to tailcontrol over
// the ts2021 protocol. // the ts2021 protocol.
type noiseClient struct { type NoiseClient struct {
// Client is an HTTP client to talk to the coordination server. // Client is an HTTP client to talk to the coordination server.
// It automatically makes a new Noise connection as needed. // It automatically makes a new Noise connection as needed.
// It does not support node key proofs. To do that, call // It does not support node key proofs. To do that, call
@ -175,11 +175,11 @@ type noiseClient struct {
connPool map[int]*noiseConn // active connections not yet closed; see noiseConn.Close connPool map[int]*noiseConn // active connections not yet closed; see noiseConn.Close
} }
// newNoiseClient returns a new noiseClient for the provided server and machine key. // NewNoiseClient returns a new noiseClient for the provided server and machine key.
// serverURL is of the form https://<host>:<port> (no trailing slash). // serverURL is of the form https://<host>:<port> (no trailing slash).
// //
// dialPlan may be nil // dialPlan may be nil
func newNoiseClient(privKey key.MachinePrivate, serverPubKey key.MachinePublic, serverURL string, dialer *tsdial.Dialer, dialPlan func() *tailcfg.ControlDialPlan) (*noiseClient, error) { func NewNoiseClient(privKey key.MachinePrivate, serverPubKey key.MachinePublic, serverURL string, dialer *tsdial.Dialer, dialPlan func() *tailcfg.ControlDialPlan) (*NoiseClient, error) {
u, err := url.Parse(serverURL) u, err := url.Parse(serverURL)
if err != nil { if err != nil {
return nil, err return nil, err
@ -200,7 +200,7 @@ func newNoiseClient(privKey key.MachinePrivate, serverPubKey key.MachinePublic,
httpPort = "80" httpPort = "80"
httpsPort = "443" httpsPort = "443"
} }
np := &noiseClient{ np := &NoiseClient{
serverPubKey: serverPubKey, serverPubKey: serverPubKey,
privKey: privKey, privKey: privKey,
host: u.Hostname(), host: u.Hostname(),
@ -227,7 +227,30 @@ func newNoiseClient(privKey key.MachinePrivate, serverPubKey key.MachinePublic,
return np, nil return np, nil
} }
func (nc *noiseClient) getConn(ctx context.Context) (*noiseConn, error) { // GetSingleUseRoundTripper returns a RoundTripper that can be only be used once
// (and must be used once) to make a single HTTP request over the noise channel
// to the coordination server.
//
// In addition to the RoundTripper, it returns the HTTP/2 channel's early noise
// payload, if any.
func (nc *NoiseClient) GetSingleUseRoundTripper(ctx context.Context) (http.RoundTripper, *tailcfg.EarlyNoise, error) {
for tries := 0; tries < 3; tries++ {
conn, err := nc.getConn(ctx)
if err != nil {
return nil, nil, err
}
earlyPayloadMaybeNil, err := conn.getEarlyPayload(ctx)
if err != nil {
return nil, nil, err
}
if conn.h2cc.ReserveNewRequest() {
return conn, earlyPayloadMaybeNil, nil
}
}
return nil, nil, errors.New("[unexpected] failed to reserve a request on a connection")
}
func (nc *NoiseClient) getConn(ctx context.Context) (*noiseConn, error) {
nc.mu.Lock() nc.mu.Lock()
if last := nc.last; last != nil && last.canTakeNewRequest() { if last := nc.last; last != nil && last.canTakeNewRequest() {
nc.mu.Unlock() nc.mu.Unlock()
@ -242,7 +265,7 @@ func (nc *noiseClient) getConn(ctx context.Context) (*noiseConn, error) {
return conn, nil return conn, nil
} }
func (nc *noiseClient) RoundTrip(req *http.Request) (*http.Response, error) { func (nc *NoiseClient) RoundTrip(req *http.Request) (*http.Response, error) {
ctx := req.Context() ctx := req.Context()
conn, err := nc.getConn(ctx) conn, err := nc.getConn(ctx)
if err != nil { if err != nil {
@ -253,7 +276,7 @@ func (nc *noiseClient) RoundTrip(req *http.Request) (*http.Response, error) {
// connClosed removes the connection with the provided ID from the pool // connClosed removes the connection with the provided ID from the pool
// of active connections. // of active connections.
func (nc *noiseClient) connClosed(id int) { func (nc *NoiseClient) connClosed(id int) {
nc.mu.Lock() nc.mu.Lock()
defer nc.mu.Unlock() defer nc.mu.Unlock()
conn := nc.connPool[id] conn := nc.connPool[id]
@ -267,7 +290,7 @@ func (nc *noiseClient) connClosed(id int) {
// Close closes all the underlying noise connections. // Close closes all the underlying noise connections.
// It is a no-op and returns nil if the connection is already closed. // It is a no-op and returns nil if the connection is already closed.
func (nc *noiseClient) Close() error { func (nc *NoiseClient) Close() error {
nc.mu.Lock() nc.mu.Lock()
conns := nc.connPool conns := nc.connPool
nc.connPool = nil nc.connPool = nil
@ -284,7 +307,7 @@ func (nc *noiseClient) Close() error {
// dial opens a new connection to tailcontrol, fetching the server noise key // dial opens a new connection to tailcontrol, fetching the server noise key
// if not cached. // if not cached.
func (nc *noiseClient) dial() (*noiseConn, error) { func (nc *NoiseClient) dial() (*noiseConn, error) {
nc.mu.Lock() nc.mu.Lock()
connID := nc.nextID connID := nc.nextID
nc.nextID++ nc.nextID++
@ -369,7 +392,7 @@ func (nc *noiseClient) dial() (*noiseConn, error) {
return ncc, nil return ncc, nil
} }
func (nc *noiseClient) post(ctx context.Context, path string, body any) (*http.Response, error) { func (nc *NoiseClient) post(ctx context.Context, path string, body any) (*http.Response, error) {
jbody, err := json.Marshal(body) jbody, err := json.Marshal(body)
if err != nil { if err != nil {
return nil, err return nil, err

@ -75,7 +75,7 @@ func (tt noiseClientTest) run(t *testing.T) {
defer hs.Close() defer hs.Close()
dialer := new(tsdial.Dialer) dialer := new(tsdial.Dialer)
nc, err := newNoiseClient(clientPrivate, serverPrivate.Public(), hs.URL, dialer, nil) nc, err := NewNoiseClient(clientPrivate, serverPrivate.Public(), hs.URL, dialer, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

Loading…
Cancel
Save