From ecf6cdd830be2699dcb063996ab82c366cf4bfaf Mon Sep 17 00:00:00 2001 From: Maisem Ali Date: Sat, 8 Oct 2022 17:54:53 -0700 Subject: [PATCH] ssh/tailssh: add TestSSHAuthFlow Signed-off-by: Maisem Ali --- net/nettest/conn.go | 21 ++- ssh/tailssh/tailssh.go | 33 +++-- ssh/tailssh/tailssh_test.go | 260 +++++++++++++++++++++++++++++++++++- 3 files changed, 299 insertions(+), 15 deletions(-) diff --git a/net/nettest/conn.go b/net/nettest/conn.go index 90727c4a8..5e26f8e41 100644 --- a/net/nettest/conn.go +++ b/net/nettest/conn.go @@ -6,6 +6,7 @@ package nettest import ( "net" + "net/netip" "time" ) @@ -32,20 +33,38 @@ func NewConn(name string, maxBuf int) (Conn, Conn) { return &connHalf{r: r, w: w}, &connHalf{r: w, w: r} } +// NewTCPConn creates a pair of Conns that are wired together by pipes. +func NewTCPConn(src, dst netip.AddrPort, maxBuf int) (local Conn, remote Conn) { + r := NewPipe(src.String(), maxBuf) + w := NewPipe(dst.String(), maxBuf) + + lAddr := net.TCPAddrFromAddrPort(src) + rAddr := net.TCPAddrFromAddrPort(dst) + + return &connHalf{r: r, w: w, remote: rAddr, local: lAddr}, &connHalf{r: w, w: r, remote: lAddr, local: rAddr} +} + type connAddr string func (a connAddr) Network() string { return "mem" } func (a connAddr) String() string { return string(a) } type connHalf struct { - r, w *Pipe + local, remote net.Addr + r, w *Pipe } func (c *connHalf) LocalAddr() net.Addr { + if c.local != nil { + return c.local + } return connAddr(c.r.name) } func (c *connHalf) RemoteAddr() net.Addr { + if c.remote != nil { + return c.remote + } return connAddr(c.w.name) } diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index ec3889c20..8d1b968d6 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -39,6 +39,7 @@ import ( "tailscale.com/tailcfg" "tailscale.com/tempfork/gliderlabs/ssh" "tailscale.com/types/logger" + "tailscale.com/types/netmap" "tailscale.com/util/clientmetric" "tailscale.com/util/mak" ) @@ -47,8 +48,19 @@ var ( sshVerboseLogging = envknob.RegisterBool("TS_DEBUG_SSH_VLOG") ) +// ipnLocalBackend is the subset of ipnlocal.LocalBackend that we use. +// It is used for testing. +type ipnLocalBackend interface { + GetSSH_HostKeys() ([]gossh.Signer, error) + ShouldRunSSH() bool + NetMap() *netmap.NetworkMap + WhoIs(ipp netip.AddrPort) (n *tailcfg.Node, u tailcfg.UserProfile, ok bool) + DoNoiseRequest(req *http.Request) (*http.Response, error) + TailscaleVarRoot() string +} + type server struct { - lb *ipnlocal.LocalBackend + lb ipnLocalBackend logf logger.Logf tailscaledPath string @@ -212,7 +224,10 @@ func (c *conn) logf(format string, args ...any) { c.srv.logf(format, args...) } -// isAuthorized returns nil if the connection is authorized to proceed. +// isAuthorized walks through the action chain and returns nil if the connection +// is authorized. If the connection is not authorized, it returns +// gossh.ErrDenied. If the action chain resolution fails, it returns the +// resolution error. func (c *conn) isAuthorized(ctx ssh.Context) error { action := c.currentAction for { @@ -525,7 +540,7 @@ func (c *conn) setInfo(ctx ssh.Context) error { return fmt.Errorf("unknown Tailscale identity from src %v", ci.src) } ci.node = node - ci.uprof = &uprof + ci.uprof = uprof c.idH = ctx.SessionID() c.info = ci @@ -743,12 +758,8 @@ func (c *conn) expandPublicKeyURL(pubKeyURL string) string { if !strings.Contains(pubKeyURL, "$") { return pubKeyURL } - var localPart string - var loginName string - if c.info.uprof != nil { - loginName = c.info.uprof.LoginName - localPart, _, _ = strings.Cut(loginName, "@") - } + loginName := c.info.uprof.LoginName + localPart, _, _ := strings.Cut(loginName, "@") return strings.NewReplacer( "$LOGINNAME_EMAIL", loginName, "$LOGINNAME_LOCALPART", localPart, @@ -1108,7 +1119,7 @@ type sshConnInfo struct { node *tailcfg.Node // uprof is node's UserProfile. - uprof *tailcfg.UserProfile + uprof tailcfg.UserProfile } func (ci *sshConnInfo) String() string { @@ -1223,7 +1234,7 @@ func (c *conn) principalMatchesTailscaleIdentity(p *tailcfg.SSHPrincipal) bool { return true } } - if p.UserLogin != "" && ci.uprof != nil && ci.uprof.LoginName == p.UserLogin { + if p.UserLogin != "" && ci.uprof.LoginName == p.UserLogin { return true } return false diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index 3e7b1eb36..c83d57fd2 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -9,7 +9,10 @@ package tailssh import ( "bytes" + "crypto/ed25519" + "crypto/rand" "crypto/sha256" + "encoding/json" "errors" "fmt" "io" @@ -21,20 +24,27 @@ import ( "os/exec" "os/user" "reflect" + "runtime" "strings" + "sync" "sync/atomic" "testing" "time" + gossh "github.com/tailscale/golang-x-crypto/ssh" "tailscale.com/ipn/ipnlocal" "tailscale.com/ipn/store/mem" + "tailscale.com/net/nettest" "tailscale.com/net/tsdial" "tailscale.com/tailcfg" "tailscale.com/tempfork/gliderlabs/ssh" "tailscale.com/tstest" "tailscale.com/types/logger" + "tailscale.com/types/netmap" "tailscale.com/util/cibuild" "tailscale.com/util/lineread" + "tailscale.com/util/must" + "tailscale.com/util/strs" "tailscale.com/wgengine" ) @@ -173,7 +183,7 @@ func TestMatchRule(t *testing.T) { Principals: []*tailcfg.SSHPrincipal{{UserLogin: "foo@bar.com"}}, SSHUsers: map[string]string{"*": "ubuntu"}, }, - ci: &sshConnInfo{uprof: &tailcfg.UserProfile{LoginName: "foo@bar.com"}}, + ci: &sshConnInfo{uprof: tailcfg.UserProfile{LoginName: "foo@bar.com"}}, wantUser: "ubuntu", }, { @@ -211,6 +221,250 @@ func TestMatchRule(t *testing.T) { func timePtr(t time.Time) *time.Time { return &t } +// localState implements ipnLocalBackend for testing. +type localState struct { + sshEnabled bool + matchingRule *tailcfg.SSHRule + + // serverActions is a map of the action name to the action. + // It is served for paths like https://unused/ssh-action/. + // The action name is the last part of the action URL. + serverActions map[string]*tailcfg.SSHAction +} + +var ( + currentUser = os.Getenv("USER") // Use the current user for the test. + testSigner gossh.Signer + testSignerOnce sync.Once +) + +func (ts *localState) GetSSH_HostKeys() ([]gossh.Signer, error) { + testSignerOnce.Do(func() { + _, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + panic(err) + } + s, err := gossh.NewSignerFromSigner(priv) + if err != nil { + panic(err) + } + testSigner = s + }) + return []gossh.Signer{testSigner}, nil +} + +func (ts *localState) ShouldRunSSH() bool { + return ts.sshEnabled +} + +func (ts *localState) NetMap() *netmap.NetworkMap { + var policy *tailcfg.SSHPolicy + if ts.matchingRule != nil { + policy = &tailcfg.SSHPolicy{ + Rules: []*tailcfg.SSHRule{ + ts.matchingRule, + }, + } + } + + return &netmap.NetworkMap{ + SelfNode: &tailcfg.Node{ + ID: 1, + }, + SSHPolicy: policy, + } +} + +func (ts *localState) WhoIs(ipp netip.AddrPort) (n *tailcfg.Node, u tailcfg.UserProfile, ok bool) { + return &tailcfg.Node{ + ID: 2, + StableID: "peer-id", + }, tailcfg.UserProfile{ + LoginName: "peer", + }, true + +} + +func (ts *localState) DoNoiseRequest(req *http.Request) (*http.Response, error) { + rec := httptest.NewRecorder() + k, ok := strs.CutPrefix(req.URL.Path, "/ssh-action/") + if !ok { + rec.WriteHeader(http.StatusNotFound) + } + a, ok := ts.serverActions[k] + if !ok { + rec.WriteHeader(http.StatusNotFound) + return rec.Result(), nil + } + rec.WriteHeader(http.StatusOK) + if err := json.NewEncoder(rec).Encode(a); err != nil { + return nil, err + } + return rec.Result(), nil +} + +func (ts *localState) TailscaleVarRoot() string { + return "" +} + +func newSSHRule(action *tailcfg.SSHAction) *tailcfg.SSHRule { + return &tailcfg.SSHRule{ + SSHUsers: map[string]string{ + "*": currentUser, + }, + Action: action, + Principals: []*tailcfg.SSHPrincipal{ + { + Any: true, + }, + }, + } +} + +func TestSSHAuthFlow(t *testing.T) { + if runtime.GOOS != "linux" { + t.Skip("Not running on Linux, skipping") + } + acceptRule := newSSHRule(&tailcfg.SSHAction{ + Accept: true, + Message: "Welcome to Tailscale SSH!", + }) + rejectRule := newSSHRule(&tailcfg.SSHAction{ + Reject: true, + Message: "Go Away!", + }) + + tests := []struct { + name string + state *localState + wantBanner string + authErr bool + }{ + { + name: "no-policy", + state: &localState{ + sshEnabled: true, + }, + authErr: true, + }, + { + name: "accept", + state: &localState{ + sshEnabled: true, + matchingRule: acceptRule, + }, + wantBanner: "Welcome to Tailscale SSH!", + }, + { + name: "reject", + state: &localState{ + sshEnabled: true, + matchingRule: rejectRule, + }, + wantBanner: "Go Away!", + authErr: true, + }, + { + name: "simple-check", + state: &localState{ + sshEnabled: true, + matchingRule: newSSHRule(&tailcfg.SSHAction{ + HoldAndDelegate: "https://unused/ssh-action/accept", + }), + serverActions: map[string]*tailcfg.SSHAction{ + "accept": acceptRule.Action, + }, + }, + wantBanner: "Welcome to Tailscale SSH!", + }, + { + name: "multi-check", + state: &localState{ + sshEnabled: true, + matchingRule: newSSHRule(&tailcfg.SSHAction{ + HoldAndDelegate: "https://unused/ssh-action/check1", + }), + serverActions: map[string]*tailcfg.SSHAction{ + "check1": { + Message: "url-here", + HoldAndDelegate: "https://unused/ssh-action/check2", + }, + "check2": acceptRule.Action, + }, + }, + wantBanner: "url-here", + }, + { + name: "check-reject", + state: &localState{ + sshEnabled: true, + matchingRule: newSSHRule(&tailcfg.SSHAction{ + HoldAndDelegate: "https://unused/ssh-action/reject", + }), + serverActions: map[string]*tailcfg.SSHAction{ + "reject": rejectRule.Action, + }, + }, + wantBanner: "Go Away!", + authErr: true, + }, + } + s := &server{ + logf: logger.Discard, + } + defer s.Shutdown() + src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22")) + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + sc, dc := nettest.NewTCPConn(src, dst, 1024) + s.lb = tc.state + cfg := &gossh.ClientConfig{ + User: "alice", + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + BannerCallback: func(message string) error { + if message != tc.wantBanner { + t.Errorf("BannerCallback = %q; want %q", message, tc.wantBanner) + } + return nil + }, + } + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg) + if err != nil { + if !tc.authErr { + t.Errorf("client: %v", err) + } + return + } else if tc.authErr { + c.Close() + t.Errorf("client: expected error, got nil") + return + } + client := gossh.NewClient(c, chans, reqs) + defer client.Close() + session, err := client.NewSession() + if err != nil { + t.Errorf("client: %v", err) + return + } + defer session.Close() + o, err := session.CombinedOutput("echo Ran echo!") + if err != nil { + t.Errorf("client: %v", err) + } + t.Logf("output: %s", o) + }() + if err := s.HandleSSHConn(dc); err != nil { + t.Errorf("unexpected error: %v", err) + } + wg.Wait() + }) + } +} + func TestSSH(t *testing.T) { var logf logger.Logf = t.Logf eng, err := wgengine.NewFakeUserspaceEngine(logf, 0) @@ -249,7 +503,7 @@ func TestSSH(t *testing.T) { src: netip.MustParseAddrPort("1.2.3.4:32342"), dst: netip.MustParseAddrPort("1.2.3.5:22"), node: &tailcfg.Node{}, - uprof: &tailcfg.UserProfile{}, + uprof: tailcfg.UserProfile{}, } sc.finalAction = &tailcfg.SSHAction{Accept: true} @@ -428,7 +682,7 @@ func TestPublicKeyFetching(t *testing.T) { func TestExpandPublicKeyURL(t *testing.T) { c := &conn{ info: &sshConnInfo{ - uprof: &tailcfg.UserProfile{ + uprof: tailcfg.UserProfile{ LoginName: "bar@baz.tld", }, },