ssh/tailssh: cache public keys fetched from URLs

Updates #3802

Change-Id: I96715bae02bce6ea19f16b1736d1bbcd7bcf3534
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/4431/head
Brad Fitzpatrick 3 years ago committed by Brad Fitzpatrick
parent 3ffd88a84a
commit 93221b4535

@ -53,10 +53,21 @@ type server struct {
logf logger.Logf logf logger.Logf
tailscaledPath string tailscaledPath string
// mu protects activeSessions. pubKeyHTTPClient *http.Client // or nil for http.DefaultClient
timeNow func() time.Time // or nil for time.Now
// mu protects the following
mu sync.Mutex mu sync.Mutex
activeSessionByH map[string]*sshSession // ssh.SessionID (DH H) => that session activeSessionByH map[string]*sshSession // ssh.SessionID (DH H) => session
activeSessionBySharedID map[string]*sshSession // yyymmddThhmmss-XXXXX => session activeSessionBySharedID map[string]*sshSession // yyymmddThhmmss-XXXXX => session
fetchPublicKeysCache map[string]pubKeyCacheEntry // by https URL
}
func (srv *server) now() time.Time {
if srv.timeNow != nil {
return srv.timeNow()
}
return time.Now()
} }
func init() { func init() {
@ -264,7 +275,7 @@ func (srv *server) evaluatePolicy(sshUser string, localAddr, remoteAddr netaddr.
return nil, nil, "", fmt.Errorf("unknown Tailscale identity from src %v", remoteAddr) return nil, nil, "", fmt.Errorf("unknown Tailscale identity from src %v", remoteAddr)
} }
ci := &sshConnInfo{ ci := &sshConnInfo{
now: time.Now(), now: srv.now(),
fetchPublicKeysURL: srv.fetchPublicKeysURL, fetchPublicKeysURL: srv.fetchPublicKeysURL,
sshUser: sshUser, sshUser: sshUser,
src: remoteAddr, src: remoteAddr,
@ -280,11 +291,58 @@ func (srv *server) evaluatePolicy(sshUser string, localAddr, remoteAddr netaddr.
return a, ci, localUser, nil return a, ci, localUser, nil
} }
// pubKeyCacheEntry is the cache value for an HTTPS URL of public keys (like
// "https://github.com/foo.keys")
type pubKeyCacheEntry struct {
lines []string
etag string // if sent by server
at time.Time
}
const (
pubKeyCacheDuration = time.Minute // how long to cache non-empty public keys
pubKeyCacheEmptyDuration = 15 * time.Second // how long to cache empty responses
)
func (srv *server) fetchPublicKeysURLCached(url string) (ce pubKeyCacheEntry, ok bool) {
srv.mu.Lock()
defer srv.mu.Unlock()
// Mostly don't care about the size of this cache. Clean rarely.
if m := srv.fetchPublicKeysCache; len(m) > 50 {
tooOld := srv.now().Add(pubKeyCacheDuration * 10)
for k, ce := range m {
if ce.at.Before(tooOld) {
delete(m, k)
}
}
}
ce, ok = srv.fetchPublicKeysCache[url]
if !ok {
return ce, false
}
maxAge := pubKeyCacheDuration
if len(ce.lines) == 0 {
maxAge = pubKeyCacheEmptyDuration
}
return ce, srv.now().Sub(ce.at) < maxAge
}
func (srv *server) pubKeyClient() *http.Client {
if srv.pubKeyHTTPClient != nil {
return srv.pubKeyHTTPClient
}
return http.DefaultClient
}
func (srv *server) fetchPublicKeysURL(url string) ([]string, error) { func (srv *server) fetchPublicKeysURL(url string) ([]string, error) {
if !strings.HasPrefix(url, "https://") { if !strings.HasPrefix(url, "https://") {
return nil, errors.New("invalid URL scheme") return nil, errors.New("invalid URL scheme")
} }
// TODO(bradfitz): add caching
ce, ok := srv.fetchPublicKeysURLCached(url)
if ok {
return ce.lines, nil
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
@ -292,16 +350,40 @@ func (srv *server) fetchPublicKeysURL(url string) ([]string, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
res, err := http.DefaultClient.Do(req) if ce.etag != "" {
req.Header.Add("If-None-Match", ce.etag)
}
res, err := srv.pubKeyClient().Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer res.Body.Close() defer res.Body.Close()
if res.StatusCode != http.StatusOK { var lines []string
return nil, errors.New(res.Status) var etag string
switch res.StatusCode {
default:
err = fmt.Errorf("unexpected status %v", res.Status)
srv.logf("fetching public keys from %s: %v", url, err)
case http.StatusNotModified:
lines = ce.lines
etag = ce.etag
case http.StatusOK:
var all []byte
all, err = io.ReadAll(io.LimitReader(res.Body, 4<<10))
if s := strings.TrimSpace(string(all)); s != "" {
lines = strings.Split(s, "\n")
}
etag = res.Header.Get("Etag")
} }
all, err := io.ReadAll(io.LimitReader(res.Body, 4<<10))
return strings.Split(string(all), "\n"), err srv.mu.Lock()
defer srv.mu.Unlock()
mapSet(&srv.fetchPublicKeysCache, url, pubKeyCacheEntry{
at: srv.now(),
lines: lines,
etag: etag,
})
return lines, err
} }
// handleSSH is invoked when a new SSH connection attempt is made. // handleSSH is invoked when a new SSH connection attempt is made.
@ -523,26 +605,20 @@ func (srv *server) getSessionForContext(sctx ssh.Context) (ss *sshSession, ok bo
func (srv *server) startSession(ss *sshSession) { func (srv *server) startSession(ss *sshSession) {
srv.mu.Lock() srv.mu.Lock()
defer srv.mu.Unlock() defer srv.mu.Unlock()
if srv.activeSessionByH == nil {
srv.activeSessionByH = make(map[string]*sshSession)
}
if srv.activeSessionBySharedID == nil {
srv.activeSessionBySharedID = make(map[string]*sshSession)
}
if ss.idH == "" { if ss.idH == "" {
panic("empty idH") panic("empty idH")
} }
if _, dup := srv.activeSessionByH[ss.idH]; dup {
panic("dup idH")
}
if ss.sharedID == "" { if ss.sharedID == "" {
panic("empty sharedID") panic("empty sharedID")
} }
if _, dup := srv.activeSessionByH[ss.idH]; dup {
panic("dup idH")
}
if _, dup := srv.activeSessionBySharedID[ss.sharedID]; dup { if _, dup := srv.activeSessionBySharedID[ss.sharedID]; dup {
panic("dup sharedID") panic("dup sharedID")
} }
srv.activeSessionByH[ss.idH] = ss mapSet(&srv.activeSessionByH, ss.idH, ss)
srv.activeSessionBySharedID[ss.sharedID] = ss mapSet(&srv.activeSessionBySharedID, ss.sharedID, ss)
} }
// endSession unregisters s from the list of active sessions. // endSession unregisters s from the list of active sessions.
@ -1057,3 +1133,11 @@ func envEq(a, b string) bool {
} }
return a == b return a == b
} }
// mapSet assigns m[k] = v, making m if necessary.
func mapSet[K comparable, V any](m *map[K]V, k K, v V) {
if *m == nil {
*m = make(map[K]V)
}
(*m)[k] = v
}

@ -9,13 +9,19 @@ package tailssh
import ( import (
"bytes" "bytes"
"crypto/sha256"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"net/http"
"net/http/httptest"
"os" "os"
"os/exec" "os/exec"
"os/user" "os/user"
"reflect"
"strings" "strings"
"sync/atomic"
"testing" "testing"
"time" "time"
@ -25,6 +31,7 @@ import (
"tailscale.com/net/tsdial" "tailscale.com/net/tsdial"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/tempfork/gliderlabs/ssh" "tailscale.com/tempfork/gliderlabs/ssh"
"tailscale.com/tstest"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/util/cibuild" "tailscale.com/util/cibuild"
"tailscale.com/util/lineread" "tailscale.com/util/lineread"
@ -336,3 +343,63 @@ func parseEnv(out []byte) map[string]string {
}) })
return e return e
} }
func TestPublicKeyFetching(t *testing.T) {
var reqsTotal, reqsIfNoneMatchHit, reqsIfNoneMatchMiss int32
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32((&reqsTotal), 1)
etag := fmt.Sprintf("W/%q", sha256.Sum256([]byte(r.URL.Path)))
w.Header().Set("Etag", etag)
if v := r.Header.Get("If-None-Match"); v != "" {
if v == etag {
atomic.AddInt32(&reqsIfNoneMatchHit, 1)
w.WriteHeader(304)
return
}
atomic.AddInt32(&reqsIfNoneMatchMiss, 1)
}
io.WriteString(w, "foo\nbar\n"+string(r.URL.Path)+"\n")
}))
ts.StartTLS()
defer ts.Close()
keys := ts.URL
clock := &tstest.Clock{}
srv := &server{
pubKeyHTTPClient: ts.Client(),
timeNow: clock.Now,
}
for i := 0; i < 2; i++ {
got, err := srv.fetchPublicKeysURL(keys + "/alice.keys")
if err != nil {
t.Fatal(err)
}
if want := []string{"foo", "bar", "/alice.keys"}; !reflect.DeepEqual(got, want) {
t.Errorf("got %q; want %q", got, want)
}
}
if got, want := atomic.LoadInt32(&reqsTotal), int32(1); got != want {
t.Errorf("got %d requests; want %d", got, want)
}
if got, want := atomic.LoadInt32(&reqsIfNoneMatchHit), int32(0); got != want {
t.Errorf("got %d etag hits; want %d", got, want)
}
clock.Advance(5 * time.Minute)
got, err := srv.fetchPublicKeysURL(keys + "/alice.keys")
if err != nil {
t.Fatal(err)
}
if want := []string{"foo", "bar", "/alice.keys"}; !reflect.DeepEqual(got, want) {
t.Errorf("got %q; want %q", got, want)
}
if got, want := atomic.LoadInt32(&reqsTotal), int32(2); got != want {
t.Errorf("got %d requests; want %d", got, want)
}
if got, want := atomic.LoadInt32(&reqsIfNoneMatchHit), int32(1); got != want {
t.Errorf("got %d etag hits; want %d", got, want)
}
if got, want := atomic.LoadInt32(&reqsIfNoneMatchMiss), int32(0); got != want {
t.Errorf("got %d etag misses; want %d", got, want)
}
}

Loading…
Cancel
Save