ssh/tailssh: handle dialing multiple recorders and failing open

This adds support to try dialing out to multiple recorders each
with a 5s timeout and an overall 30s timeout. It also starts respecting
the actions `OnRecordingFailure` field if set, if it is not set
it fails open.

Updates tailscale/corp#9967

Signed-off-by: Maisem Ali <maisem@tailscale.com>
pull/7950/head
Maisem Ali 2 years ago committed by Maisem Ali
parent f66ddb544c
commit 7778d708a6

@ -17,6 +17,7 @@ import (
"io"
"net"
"net/http"
"net/http/httptrace"
"net/netip"
"net/url"
"os"
@ -42,6 +43,7 @@ import (
"tailscale.com/types/netmap"
"tailscale.com/util/clientmetric"
"tailscale.com/util/mak"
"tailscale.com/util/multierr"
"tailscale.com/version/distro"
)
@ -79,33 +81,11 @@ type server struct {
// mu protects the following
mu sync.Mutex
httpc *http.Client // for calling out to peers.
activeConns map[*conn]bool // set; value is always true
fetchPublicKeysCache map[string]pubKeyCacheEntry // by https URL
shutdownCalled bool
}
// sessionRecordingClient returns an http.Client that uses srv.lb.Dialer() to
// dial connections. This is used to make requests to the session recording
// server to upload session recordings.
func (srv *server) sessionRecordingClient() *http.Client {
srv.mu.Lock()
defer srv.mu.Unlock()
if srv.httpc != nil {
return srv.httpc
}
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
return srv.lb.Dialer().UserDial(ctx, network, addr)
}
srv.httpc = &http.Client{
Transport: tr,
}
return srv.httpc
}
func (srv *server) now() time.Time {
if srv != nil && srv.timeNow != nil {
return srv.timeNow()
@ -1078,7 +1058,7 @@ func (ss *sshSession) run() {
if err != nil {
var uve userVisibleError
if errors.As(err, &uve) {
fmt.Fprintf(ss, "%s\r\n", uve)
fmt.Fprintf(ss, "%s\r\n", uve.SSHTerminationMessage())
} else {
fmt.Fprintf(ss, "can't start new recording\r\n")
}
@ -1086,9 +1066,11 @@ func (ss *sshSession) run() {
ss.Exit(1)
return
}
if rec != nil {
defer rec.Close()
}
}
}
err := ss.launchProcess()
if err != nil {
@ -1169,15 +1151,16 @@ func (ss *sshSession) run() {
// If the final action has a non-empty list of recorders, that list is
// returned. Otherwise, the list of recorders from the initial action
// is returned.
func (ss *sshSession) recorders() []netip.AddrPort {
func (ss *sshSession) recorders() ([]netip.AddrPort, *tailcfg.SSHRecorderFailureAction) {
if len(ss.conn.finalAction.Recorders) > 0 {
return ss.conn.finalAction.Recorders
return ss.conn.finalAction.Recorders, ss.conn.finalAction.OnRecordingFailure
}
return ss.conn.action0.Recorders
return ss.conn.action0.Recorders, ss.conn.action0.OnRecordingFailure
}
func (ss *sshSession) shouldRecord() bool {
return len(ss.recorders()) > 0
recs, _ := ss.recorders()
return len(recs) > 0
}
type sshConnInfo struct {
@ -1409,16 +1392,120 @@ type CastHeader struct {
LocalUser string `json:"localUser"`
}
// sessionRecordingClient returns an http.Client that uses srv.lb.Dialer() to
// dial connections. This is used to make requests to the session recording
// server to upload session recordings.
// It uses the provided dialCtx to dial connections, and limits a single dial
// to 5 seconds.
func (ss *sshSession) sessionRecordingClient(dialCtx context.Context) (*http.Client, error) {
dialer := ss.conn.srv.lb.Dialer()
if dialer == nil {
return nil, errors.New("no peer API transport")
}
tr := dialer.PeerAPITransport().Clone()
dialContextFn := tr.DialContext
tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
perAttemptCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
go func() {
select {
case <-perAttemptCtx.Done():
case <-dialCtx.Done():
cancel()
}
}()
return dialContextFn(perAttemptCtx, network, addr)
}
return &http.Client{
Transport: tr,
}, nil
}
// connectToRecorder connects to the recorder at any of the provided addresses.
// It returns the first successful response, or a multierr if all attempts fail.
//
// On success, it returns a WriteCloser that can be used to upload the
// recording, and a channel that will be sent an error (or nil) when the upload
// fails or completes.
func (ss *sshSession) connectToRecorder(ctx context.Context, recs []netip.AddrPort) (io.WriteCloser, <-chan error, error) {
if len(recs) == 0 {
return nil, nil, errors.New("no recorders configured")
}
// We use a special context for dialing the recorder, so that we can
// limit the time we spend dialing to 30 seconds and still have an
// unbounded context for the upload.
dialCtx, dialCancel := context.WithTimeout(ctx, 30*time.Second)
defer dialCancel()
hc, err := ss.sessionRecordingClient(dialCtx)
if err != nil {
return nil, nil, err
}
var errs []error
for _, ap := range recs {
// We dial the recorder and wait for it to send a 100-continue
// response before returning from this function. This ensures that
// the recorder is ready to accept the recording.
// got100 is closed when we receive the 100-continue response.
got100 := make(chan struct{})
ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
Got100Continue: func() {
close(got100)
},
})
pr, pw := io.Pipe()
req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://%s:%d/record", ap.Addr(), ap.Port()), pr)
if err != nil {
errs = append(errs, fmt.Errorf("recording: error starting recording: %w", err))
continue
}
// We set the Expect header to 100-continue, so that the recorder
// will send a 100-continue response before it starts reading the
// request body.
req.Header.Set("Expect", "100-continue")
// errChan is used to indicate the result of the request.
errChan := make(chan error, 1)
go func() {
resp, err := hc.Do(req)
if err != nil {
errChan <- fmt.Errorf("recording: error starting recording: %w", err)
return
}
if resp.StatusCode != 200 {
errChan <- fmt.Errorf("recording: unexpected status: %v", resp.Status)
return
}
errChan <- nil
}()
select {
case <-got100:
case err := <-errChan:
// If we get an error before we get the 100-continue response,
// we need to try another recorder.
if err == nil {
// If the error is nil, we got a 200 response, which
// is unexpected as we haven't sent any data yet.
err = errors.New("recording: unexpected EOF")
}
errs = append(errs, err)
continue
}
return pw, errChan, nil
}
return nil, nil, multierr.New(errs...)
}
// startNewRecording starts a new SSH session recording.
// It may return a nil recording if recording is not available.
func (ss *sshSession) startNewRecording() (_ *recording, err error) {
recorders := ss.recorders()
recorders, onFailure := ss.recorders()
if len(recorders) == 0 {
return nil, errors.New("no recorders configured")
}
recorder := recorders[0]
if len(recorders) > 1 {
ss.logf("warning: multiple recorders configured, using first one: %v", recorder)
}
var w ssh.Window
if ptyReq, _, isPtyReq := ss.Pty(); isPtyReq {
@ -1436,51 +1523,43 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) {
start: now,
}
pr, pw := io.Pipe()
// We want to use a background context for uploading and not ss.ctx.
// ss.ctx is closed when the session closes, but we don't want to break the upload at that time.
// Instead we want to wait for the session to close the writer when it finishes.
ctx := context.Background()
req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://%s:%d/record", recorder.Addr(), recorder.Port()), pr)
wc, errChan, err := ss.connectToRecorder(ctx, recorders)
if err != nil {
pr.Close()
pw.Close()
return nil, err
// TODO(catzkorn): notify control here.
if onFailure != nil && onFailure.RejectSessionWithMessage != "" {
ss.logf("recording: error starting recording (rejecting session): %v", err)
return nil, userVisibleError{
error: err,
msg: onFailure.RejectSessionWithMessage,
}
// We want to wait for the server to respond with 100 Continue to notifiy us
// that it's ready to receive data. We do this to block the session from
// starting until the server is ready to receive data.
// It also allows the server to reject the request before we start sending
// data.
req.Header.Set("Expect", "100-continue")
}
ss.logf("recording: error starting recording (failing open): %v", err)
return nil, nil
}
go func() {
defer pw.Close()
ss.logf("starting asciinema recording to %s", recorder)
hc := ss.conn.srv.sessionRecordingClient()
resp, err := hc.Do(req)
if err != nil {
err := fmt.Errorf("recording: error sending recording: %w", err)
ss.logf("%v", err)
ss.cancelCtx(userVisibleError{
msg: "recording: error sending recording",
error: err,
})
err := <-errChan
if err == nil {
// Success.
return
}
defer resp.Body.Close()
defer ss.cancelCtx(errors.New("recording: done"))
if resp.StatusCode != http.StatusOK {
err := fmt.Errorf("recording: server responded with %s", resp.Status)
ss.logf("%v", err)
// TODO(catzkorn): notify control here.
if onFailure != nil && onFailure.TerminateSessionWithMessage != "" {
ss.logf("recording: error uploading recording (closing session): %v", err)
ss.cancelCtx(userVisibleError{
msg: "recording server responded with: " + resp.Status,
error: err,
msg: onFailure.TerminateSessionWithMessage,
})
return
}
ss.logf("recording: error uploading recording (failing open): %v", err)
}()
rec.out = pw
rec.out = wc
ch := CastHeader{
Version: 2,
@ -1515,7 +1594,7 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) {
return nil, err
}
j = append(j, '\n')
if _, err := pw.Write(j); err != nil {
if _, err := rec.out.Write(j); err != nil {
if errors.Is(err, io.ErrClosedPipe) && ss.ctx.Err() != nil {
// If we got an io.ErrClosedPipe, it's likely because
// the recording server closed the connection on us. Return

@ -240,7 +240,7 @@ var (
)
func (ts *localState) Dialer() *tsdial.Dialer {
return nil
return &tsdial.Dialer{}
}
func (ts *localState) GetSSH_HostKeys() ([]gossh.Signer, error) {
@ -339,7 +339,6 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
s := &server{
logf: t.Logf,
httpc: recordingServer.Client(),
lb: &localState{
sshEnabled: true,
matchingRule: newSSHRule(
@ -348,6 +347,10 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
Recorders: []netip.AddrPort{
netip.MustParseAddrPort(recordingServer.Listener.Addr().String()),
},
OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{
RejectSessionWithMessage: "session rejected",
TerminateSessionWithMessage: "session terminated",
},
},
),
},
@ -374,7 +377,7 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
w.WriteHeader(http.StatusForbidden)
},
sshCommand: "echo hello",
wantClientOutput: "recording: server responded with 403 Forbidden\r\n",
wantClientOutput: "session rejected\r\n",
clientOutputMustNotContain: []string{"hello"},
},
@ -386,7 +389,7 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
w.WriteHeader(http.StatusInternalServerError)
},
sshCommand: "echo hello && sleep 1 && echo world",
wantClientOutput: "\r\n\r\nrecording server responded with: 500 Internal Server Error\r\n\r\n",
wantClientOutput: "\r\n\r\nsession terminated\r\n\r\n",
clientOutputMustNotContain: []string{"world"},
},
@ -440,6 +443,103 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
}
}
func TestMultipleRecorders(t *testing.T) {
if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
}
done := make(chan struct{})
recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer close(done)
io.ReadAll(r.Body)
w.WriteHeader(http.StatusOK)
}))
defer recordingServer.Close()
badRecorder, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
badRecorderAddr := badRecorder.Addr().String()
badRecorder.Close()
badRecordingServer500 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(500)
}))
defer badRecordingServer500.Close()
badRecordingServer200 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
}))
defer badRecordingServer200.Close()
s := &server{
logf: t.Logf,
lb: &localState{
sshEnabled: true,
matchingRule: newSSHRule(
&tailcfg.SSHAction{
Accept: true,
Recorders: []netip.AddrPort{
netip.MustParseAddrPort(badRecorderAddr),
netip.MustParseAddrPort(badRecordingServer500.Listener.Addr().String()),
netip.MustParseAddrPort(badRecordingServer200.Listener.Addr().String()),
netip.MustParseAddrPort(recordingServer.Listener.Addr().String()),
},
OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{
RejectSessionWithMessage: "session rejected",
TerminateSessionWithMessage: "session terminated",
},
},
),
},
}
defer s.Shutdown()
src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
sc, dc := memnet.NewTCPConn(src, dst, 1024)
const sshUser = "alice"
cfg := &gossh.ClientConfig{
User: sshUser,
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
if err != nil {
t.Errorf("client: %v", err)
return
}
client := gossh.NewClient(c, chans, reqs)
defer client.Close()
session, err := client.NewSession()
if err != nil {
t.Errorf("client: %v", err)
return
}
defer session.Close()
t.Logf("client established session")
out, err := session.CombinedOutput("echo Ran echo!")
if err != nil {
t.Errorf("client: %v", err)
}
if string(out) != "Ran echo!\n" {
t.Errorf("client: unexpected output: %q", out)
}
}()
if err := s.HandleSSHConn(dc); err != nil {
t.Errorf("unexpected error: %v", err)
}
wg.Wait()
select {
case <-done:
case <-time.After(1 * time.Second):
t.Fatal("timed out waiting for recording")
}
}
// TestSSHRecordingNonInteractive tests that the SSH server records the SSH session
// when the client is not interactive (i.e. no PTY).
// It starts a local SSH server and a recording server. The recording server
@ -465,7 +565,6 @@ func TestSSHRecordingNonInteractive(t *testing.T) {
s := &server{
logf: logger.Discard,
httpc: recordingServer.Client(),
lb: &localState{
sshEnabled: true,
matchingRule: newSSHRule(
@ -474,6 +573,10 @@ func TestSSHRecordingNonInteractive(t *testing.T) {
Recorders: []netip.AddrPort{
must.Get(netip.ParseAddrPort(recordingServer.Listener.Addr().String())),
},
OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{
RejectSessionWithMessage: "session rejected",
TerminateSessionWithMessage: "session terminated",
},
},
),
},

@ -97,7 +97,8 @@ type CapabilityVersion int
// - 58: 2023-03-10: Client retries lite map updates before restarting map poll.
// - 59: 2023-03-16: Client understands Peers[].SelfNodeV4MasqAddrForThisPeer
// - 60: 2023-04-06: Client understands IsWireGuardOnly
const CurrentCapabilityVersion CapabilityVersion = 60
// - 61: 2023-04-18: Client understand SSHAction.SSHRecorderFailureAction
const CurrentCapabilityVersion CapabilityVersion = 61
type StableID string

Loading…
Cancel
Save