From 93221b45358eebdc6aea1f47826cb3cb996eae79 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Sat, 16 Apr 2022 21:49:22 -0700 Subject: [PATCH] ssh/tailssh: cache public keys fetched from URLs Updates #3802 Change-Id: I96715bae02bce6ea19f16b1736d1bbcd7bcf3534 Signed-off-by: Brad Fitzpatrick --- ssh/tailssh/tailssh.go | 126 ++++++++++++++++++++++++++++++------ ssh/tailssh/tailssh_test.go | 67 +++++++++++++++++++ 2 files changed, 172 insertions(+), 21 deletions(-) diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index c710bb74e..ae831f16d 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -53,10 +53,21 @@ type server struct { logf logger.Logf tailscaledPath string - // mu protects activeSessions. + pubKeyHTTPClient *http.Client // or nil for http.DefaultClient + timeNow func() time.Time // or nil for time.Now + + // mu protects the following mu sync.Mutex - activeSessionByH map[string]*sshSession // ssh.SessionID (DH H) => that session - activeSessionBySharedID map[string]*sshSession // yyymmddThhmmss-XXXXX => session + activeSessionByH map[string]*sshSession // ssh.SessionID (DH H) => session + activeSessionBySharedID map[string]*sshSession // yyymmddThhmmss-XXXXX => session + fetchPublicKeysCache map[string]pubKeyCacheEntry // by https URL +} + +func (srv *server) now() time.Time { + if srv.timeNow != nil { + return srv.timeNow() + } + return time.Now() } func init() { @@ -264,7 +275,7 @@ func (srv *server) evaluatePolicy(sshUser string, localAddr, remoteAddr netaddr. return nil, nil, "", fmt.Errorf("unknown Tailscale identity from src %v", remoteAddr) } ci := &sshConnInfo{ - now: time.Now(), + now: srv.now(), fetchPublicKeysURL: srv.fetchPublicKeysURL, sshUser: sshUser, src: remoteAddr, @@ -280,11 +291,58 @@ func (srv *server) evaluatePolicy(sshUser string, localAddr, remoteAddr netaddr. return a, ci, localUser, nil } +// pubKeyCacheEntry is the cache value for an HTTPS URL of public keys (like +// "https://github.com/foo.keys") +type pubKeyCacheEntry struct { + lines []string + etag string // if sent by server + at time.Time +} + +const ( + pubKeyCacheDuration = time.Minute // how long to cache non-empty public keys + pubKeyCacheEmptyDuration = 15 * time.Second // how long to cache empty responses +) + +func (srv *server) fetchPublicKeysURLCached(url string) (ce pubKeyCacheEntry, ok bool) { + srv.mu.Lock() + defer srv.mu.Unlock() + // Mostly don't care about the size of this cache. Clean rarely. + if m := srv.fetchPublicKeysCache; len(m) > 50 { + tooOld := srv.now().Add(pubKeyCacheDuration * 10) + for k, ce := range m { + if ce.at.Before(tooOld) { + delete(m, k) + } + } + } + ce, ok = srv.fetchPublicKeysCache[url] + if !ok { + return ce, false + } + maxAge := pubKeyCacheDuration + if len(ce.lines) == 0 { + maxAge = pubKeyCacheEmptyDuration + } + return ce, srv.now().Sub(ce.at) < maxAge +} + +func (srv *server) pubKeyClient() *http.Client { + if srv.pubKeyHTTPClient != nil { + return srv.pubKeyHTTPClient + } + return http.DefaultClient +} + func (srv *server) fetchPublicKeysURL(url string) ([]string, error) { if !strings.HasPrefix(url, "https://") { return nil, errors.New("invalid URL scheme") } - // TODO(bradfitz): add caching + + ce, ok := srv.fetchPublicKeysURLCached(url) + if ok { + return ce.lines, nil + } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -292,16 +350,40 @@ func (srv *server) fetchPublicKeysURL(url string) ([]string, error) { if err != nil { return nil, err } - res, err := http.DefaultClient.Do(req) + if ce.etag != "" { + req.Header.Add("If-None-Match", ce.etag) + } + res, err := srv.pubKeyClient().Do(req) if err != nil { return nil, err } defer res.Body.Close() - if res.StatusCode != http.StatusOK { - return nil, errors.New(res.Status) + var lines []string + var etag string + switch res.StatusCode { + default: + err = fmt.Errorf("unexpected status %v", res.Status) + srv.logf("fetching public keys from %s: %v", url, err) + case http.StatusNotModified: + lines = ce.lines + etag = ce.etag + case http.StatusOK: + var all []byte + all, err = io.ReadAll(io.LimitReader(res.Body, 4<<10)) + if s := strings.TrimSpace(string(all)); s != "" { + lines = strings.Split(s, "\n") + } + etag = res.Header.Get("Etag") } - all, err := io.ReadAll(io.LimitReader(res.Body, 4<<10)) - return strings.Split(string(all), "\n"), err + + srv.mu.Lock() + defer srv.mu.Unlock() + mapSet(&srv.fetchPublicKeysCache, url, pubKeyCacheEntry{ + at: srv.now(), + lines: lines, + etag: etag, + }) + return lines, err } // handleSSH is invoked when a new SSH connection attempt is made. @@ -523,26 +605,20 @@ func (srv *server) getSessionForContext(sctx ssh.Context) (ss *sshSession, ok bo func (srv *server) startSession(ss *sshSession) { srv.mu.Lock() defer srv.mu.Unlock() - if srv.activeSessionByH == nil { - srv.activeSessionByH = make(map[string]*sshSession) - } - if srv.activeSessionBySharedID == nil { - srv.activeSessionBySharedID = make(map[string]*sshSession) - } if ss.idH == "" { panic("empty idH") } - if _, dup := srv.activeSessionByH[ss.idH]; dup { - panic("dup idH") - } if ss.sharedID == "" { panic("empty sharedID") } + if _, dup := srv.activeSessionByH[ss.idH]; dup { + panic("dup idH") + } if _, dup := srv.activeSessionBySharedID[ss.sharedID]; dup { panic("dup sharedID") } - srv.activeSessionByH[ss.idH] = ss - srv.activeSessionBySharedID[ss.sharedID] = ss + mapSet(&srv.activeSessionByH, ss.idH, ss) + mapSet(&srv.activeSessionBySharedID, ss.sharedID, ss) } // endSession unregisters s from the list of active sessions. @@ -1057,3 +1133,11 @@ func envEq(a, b string) bool { } return a == b } + +// mapSet assigns m[k] = v, making m if necessary. +func mapSet[K comparable, V any](m *map[K]V, k K, v V) { + if *m == nil { + *m = make(map[K]V) + } + (*m)[k] = v +} diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index 23e2540a3..7fed58c05 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -9,13 +9,19 @@ package tailssh import ( "bytes" + "crypto/sha256" "errors" "fmt" + "io" "net" + "net/http" + "net/http/httptest" "os" "os/exec" "os/user" + "reflect" "strings" + "sync/atomic" "testing" "time" @@ -25,6 +31,7 @@ import ( "tailscale.com/net/tsdial" "tailscale.com/tailcfg" "tailscale.com/tempfork/gliderlabs/ssh" + "tailscale.com/tstest" "tailscale.com/types/logger" "tailscale.com/util/cibuild" "tailscale.com/util/lineread" @@ -336,3 +343,63 @@ func parseEnv(out []byte) map[string]string { }) return e } + +func TestPublicKeyFetching(t *testing.T) { + var reqsTotal, reqsIfNoneMatchHit, reqsIfNoneMatchMiss int32 + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32((&reqsTotal), 1) + etag := fmt.Sprintf("W/%q", sha256.Sum256([]byte(r.URL.Path))) + w.Header().Set("Etag", etag) + if v := r.Header.Get("If-None-Match"); v != "" { + if v == etag { + atomic.AddInt32(&reqsIfNoneMatchHit, 1) + w.WriteHeader(304) + return + } + atomic.AddInt32(&reqsIfNoneMatchMiss, 1) + } + io.WriteString(w, "foo\nbar\n"+string(r.URL.Path)+"\n") + })) + ts.StartTLS() + defer ts.Close() + keys := ts.URL + + clock := &tstest.Clock{} + srv := &server{ + pubKeyHTTPClient: ts.Client(), + timeNow: clock.Now, + } + for i := 0; i < 2; i++ { + got, err := srv.fetchPublicKeysURL(keys + "/alice.keys") + if err != nil { + t.Fatal(err) + } + if want := []string{"foo", "bar", "/alice.keys"}; !reflect.DeepEqual(got, want) { + t.Errorf("got %q; want %q", got, want) + } + } + if got, want := atomic.LoadInt32(&reqsTotal), int32(1); got != want { + t.Errorf("got %d requests; want %d", got, want) + } + if got, want := atomic.LoadInt32(&reqsIfNoneMatchHit), int32(0); got != want { + t.Errorf("got %d etag hits; want %d", got, want) + } + clock.Advance(5 * time.Minute) + got, err := srv.fetchPublicKeysURL(keys + "/alice.keys") + if err != nil { + t.Fatal(err) + } + if want := []string{"foo", "bar", "/alice.keys"}; !reflect.DeepEqual(got, want) { + t.Errorf("got %q; want %q", got, want) + } + if got, want := atomic.LoadInt32(&reqsTotal), int32(2); got != want { + t.Errorf("got %d requests; want %d", got, want) + } + if got, want := atomic.LoadInt32(&reqsIfNoneMatchHit), int32(1); got != want { + t.Errorf("got %d etag hits; want %d", got, want) + } + if got, want := atomic.LoadInt32(&reqsIfNoneMatchMiss), int32(0); got != want { + t.Errorf("got %d etag misses; want %d", got, want) + } + +}