diff --git a/clientupdate/clientupdate.go b/clientupdate/clientupdate.go index 5355d01c3..39d1148c7 100644 --- a/clientupdate/clientupdate.go +++ b/clientupdate/clientupdate.go @@ -10,8 +10,6 @@ import ( "bufio" "bytes" "context" - "crypto/sha256" - "encoding/hex" "encoding/json" "errors" "fmt" @@ -25,12 +23,10 @@ import ( "runtime" "strconv" "strings" - "time" "github.com/google/uuid" - "tailscale.com/net/tshttpproxy" + "tailscale.com/clientupdate/distsign" "tailscale.com/types/logger" - "tailscale.com/util/must" "tailscale.com/util/winutil" "tailscale.com/version" "tailscale.com/version/distro" @@ -88,6 +84,9 @@ type UpdateArgs struct { // if this new version should be installed. When Confirm returns false, the // update is aborted. Confirm func(newVer string) bool + // PkgsAddr is the address of the pkgs server to fetch updates from. + // Defaults to "https://pkgs.tailscale.com". + PkgsAddr string } func (args UpdateArgs) validate() error { @@ -109,6 +108,9 @@ func Update(args UpdateArgs) error { if err := args.validate(); err != nil { return err } + if args.PkgsAddr == "" { + args.PkgsAddr = "https://pkgs.tailscale.com" + } up := &updater{ UpdateArgs: args, } @@ -222,10 +224,9 @@ func (up *updater) updateSynology() error { if err != nil { return err } - url := fmt.Sprintf("https://pkgs.tailscale.com/%s/%s", up.track, spkName) - spkPath := filepath.Join(spkDir, path.Base(url)) - // TODO(awly): we should sign SPKs and validate signatures here too. - if err := up.downloadURLToFile(url, spkPath); err != nil { + pkgsPath := fmt.Sprintf("%s/%s", up.track, spkName) + spkPath := filepath.Join(spkDir, path.Base(pkgsPath)) + if err := up.downloadURLToFile(pkgsPath, spkPath); err != nil { return err } @@ -650,9 +651,9 @@ func (up *updater) updateWindows() error { if err := os.MkdirAll(msiDir, 0700); err != nil { return err } - url := fmt.Sprintf("https://pkgs.tailscale.com/%s/tailscale-setup-%s-%s.msi", up.track, ver, arch) - msiTarget := filepath.Join(msiDir, path.Base(url)) - if err := up.downloadURLToFile(url, msiTarget); err != nil { + pkgsPath := fmt.Sprintf("%s/tailscale-setup-%s-%s.msi", up.track, ver, arch) + msiTarget := filepath.Join(msiDir, path.Base(pkgsPath)) + if err := up.downloadURLToFile(pkgsPath, msiTarget); err != nil { return err } @@ -751,106 +752,12 @@ func makeSelfCopy() (tmpPathExe string, err error) { return f2.Name(), f2.Close() } -func (up *updater) downloadURLToFile(urlSrc, fileDst string) (ret error) { - tr := http.DefaultTransport.(*http.Transport).Clone() - tr.Proxy = tshttpproxy.ProxyFromEnvironment - defer tr.CloseIdleConnections() - c := &http.Client{Transport: tr} - - quickCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - headReq := must.Get(http.NewRequestWithContext(quickCtx, "HEAD", urlSrc, nil)) - - res, err := c.Do(headReq) - if err != nil { - return err - } - if res.StatusCode != http.StatusOK { - return fmt.Errorf("HEAD %s: %v", urlSrc, res.Status) - } - if res.ContentLength <= 0 { - return fmt.Errorf("HEAD %s: unexpected Content-Length %v", urlSrc, res.ContentLength) - } - up.Logf("Download size: %v", res.ContentLength) - - hashReq := must.Get(http.NewRequestWithContext(quickCtx, "GET", urlSrc+".sha256", nil)) - hashRes, err := c.Do(hashReq) +func (up *updater) downloadURLToFile(pathSrc, fileDst string) (ret error) { + c, err := distsign.NewClient(up.Logf, up.PkgsAddr) if err != nil { return err } - hashHex, err := io.ReadAll(io.LimitReader(hashRes.Body, 100)) - hashRes.Body.Close() - if res.StatusCode != http.StatusOK { - return fmt.Errorf("GET %s.sha256: %v", urlSrc, res.Status) - } - if err != nil { - return err - } - wantHash, err := hex.DecodeString(string(strings.TrimSpace(string(hashHex)))) - if err != nil { - return err - } - hash := sha256.New() - - dlReq := must.Get(http.NewRequestWithContext(context.Background(), "GET", urlSrc, nil)) - dlRes, err := c.Do(dlReq) - if err != nil { - return err - } - // TODO(bradfitz): resume from existing partial file on disk - if dlRes.StatusCode != http.StatusOK { - return fmt.Errorf("GET %s: %v", urlSrc, dlRes.Status) - } - - of, err := os.Create(fileDst) - if err != nil { - return err - } - defer func() { - if ret != nil { - of.Close() - // TODO(bradfitz): os.Remove(fileDst) too? or keep it to resume from/debug later. - } - }() - pw := &progressWriter{total: res.ContentLength, logf: up.Logf} - n, err := io.Copy(io.MultiWriter(hash, of, pw), io.LimitReader(dlRes.Body, res.ContentLength)) - if err != nil { - return err - } - if n != res.ContentLength { - return fmt.Errorf("downloaded %v; want %v", n, res.ContentLength) - } - if err := of.Close(); err != nil { - return err - } - pw.print() - - if !bytes.Equal(hash.Sum(nil), wantHash) { - return fmt.Errorf("SHA-256 of downloaded MSI didn't match expected value") - } - up.Logf("hash matched") - - return nil -} - -type progressWriter struct { - done int64 - total int64 - lastPrint time.Time - logf logger.Logf -} - -func (pw *progressWriter) Write(p []byte) (n int, err error) { - pw.done += int64(len(p)) - if time.Since(pw.lastPrint) > 2*time.Second { - pw.print() - } - return len(p), nil -} - -func (pw *progressWriter) print() { - pw.lastPrint = time.Now() - pw.logf("Downloaded %v/%v (%.1f%%)", pw.done, pw.total, float64(pw.done)/float64(pw.total)*100) + return c.Download(context.Background(), pathSrc, fileDst) } func (up *updater) updateFreeBSD() (err error) { diff --git a/clientupdate/distsign/distsign.go b/clientupdate/distsign/distsign.go index 26ec292db..3dad1f4c4 100644 --- a/clientupdate/distsign/distsign.go +++ b/clientupdate/distsign/distsign.go @@ -38,6 +38,7 @@ package distsign import ( + "context" "crypto/ed25519" "crypto/rand" "encoding/binary" @@ -46,12 +47,17 @@ import ( "fmt" "hash" "io" + "log" "net/http" "net/url" "os" + "time" "github.com/hdevalence/ed25519consensus" "golang.org/x/crypto/blake2s" + "tailscale.com/net/tshttpproxy" + "tailscale.com/types/logger" + "tailscale.com/util/must" ) const ( @@ -177,18 +183,22 @@ func (ph *PackageHash) Len() int64 { return ph.len } // Client downloads and validates files from a distribution server. type Client struct { + logf logger.Logf roots []ed25519.PublicKey pkgsAddr *url.URL } // NewClient returns a new client for distribution server located at pkgsAddr, // and uses embedded root keys from the roots/ subdirectory of this package. -func NewClient(pkgsAddr string) (*Client, error) { +func NewClient(logf logger.Logf, pkgsAddr string) (*Client, error) { + if logf == nil { + logf = log.Printf + } u, err := url.Parse(pkgsAddr) if err != nil { return nil, fmt.Errorf("invalid pkgsAddr %q: %w", pkgsAddr, err) } - return &Client{roots: roots(), pkgsAddr: u}, nil + return &Client{logf: logf, roots: roots(), pkgsAddr: u}, nil } func (c *Client) url(path string) string { @@ -199,7 +209,7 @@ func (c *Client) url(path string) string { // The file is downloaded to dstPath and its signature is validated using the // embedded root keys. Download returns an error if anything goes wrong with // the actual file download or with signature validation. -func (c *Client) Download(srcPath, dstPath string) error { +func (c *Client) Download(ctx context.Context, srcPath, dstPath string) error { // Always fetch a fresh signing key. sigPub, err := c.signingKeys() if err != nil { @@ -209,11 +219,13 @@ func (c *Client) Download(srcPath, dstPath string) error { srcURL := c.url(srcPath) sigURL := srcURL + ".sig" + c.logf("Downloading %q", srcURL) dstPathUnverified := dstPath + ".unverified" - hash, len, err := download(srcURL, dstPathUnverified, downloadSizeLimit) + hash, len, err := c.download(ctx, srcURL, dstPathUnverified, downloadSizeLimit) if err != nil { return err } + c.logf("Downloading %q", sigURL) sig, err := fetch(sigURL, signatureSizeLimit) if err != nil { // Best-effort clean up of downloaded package. @@ -226,6 +238,7 @@ func (c *Client) Download(srcPath, dstPath string) error { os.Remove(dstPathUnverified) return fmt.Errorf("signature %q for file %q does not validate with the current release signing key; either you are under attack, or attempting to download an old version of Tailscale which was signed with an older signing key", sigURL, srcURL) } + c.logf("Signature OK") if err := os.Rename(dstPathUnverified, dstPath); err != nil { return fmt.Errorf("failed to move %q to %q after signature validation", dstPathUnverified, dstPath) @@ -272,32 +285,84 @@ func fetch(url string, limit int64) ([]byte, error) { // download writes the response body of url into a local file at dst, up to // limit bytes. On success, the returned value is a BLAKE2s hash of the file. -func download(url, dst string, limit int64) ([]byte, int64, error) { - resp, err := http.Get(url) +func (c *Client) download(ctx context.Context, url, dst string, limit int64) ([]byte, int64, error) { + tr := http.DefaultTransport.(*http.Transport).Clone() + tr.Proxy = tshttpproxy.ProxyFromEnvironment + defer tr.CloseIdleConnections() + hc := &http.Client{Transport: tr} + + quickCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + headReq := must.Get(http.NewRequestWithContext(quickCtx, http.MethodHead, url, nil)) + + res, err := hc.Do(headReq) if err != nil { return nil, 0, err } - defer resp.Body.Close() - - h := NewPackageHash() - r := io.TeeReader(io.LimitReader(resp.Body, limit), h) + if res.StatusCode != http.StatusOK { + return nil, 0, fmt.Errorf("HEAD %q: %v", url, res.Status) + } + if res.ContentLength <= 0 { + return nil, 0, fmt.Errorf("HEAD %q: unexpected Content-Length %v", url, res.ContentLength) + } + c.logf("Download size: %v", res.ContentLength) - f, err := os.Create(dst) + dlReq := must.Get(http.NewRequestWithContext(ctx, http.MethodGet, url, nil)) + dlRes, err := hc.Do(dlReq) if err != nil { return nil, 0, err } - defer f.Close() + defer dlRes.Body.Close() + // TODO(bradfitz): resume from existing partial file on disk + if dlRes.StatusCode != http.StatusOK { + return nil, 0, fmt.Errorf("GET %q: %v", url, dlRes.Status) + } - if _, err := io.Copy(f, r); err != nil { + of, err := os.Create(dst) + if err != nil { return nil, 0, err } - if err := f.Close(); err != nil { - return nil, 0, err + defer of.Close() + pw := &progressWriter{total: res.ContentLength, logf: c.logf} + h := NewPackageHash() + n, err := io.Copy(io.MultiWriter(of, h, pw), io.LimitReader(dlRes.Body, limit)) + if err != nil { + return nil, n, err + } + if n != res.ContentLength { + return nil, n, fmt.Errorf("GET %q: downloaded %v, want %v", url, n, res.ContentLength) } + if err := dlRes.Body.Close(); err != nil { + return nil, n, err + } + if err := of.Close(); err != nil { + return nil, n, err + } + pw.print() return h.Sum(nil), h.Len(), nil } +type progressWriter struct { + done int64 + total int64 + lastPrint time.Time + logf logger.Logf +} + +func (pw *progressWriter) Write(p []byte) (n int, err error) { + pw.done += int64(len(p)) + if time.Since(pw.lastPrint) > 2*time.Second { + pw.print() + } + return len(p), nil +} + +func (pw *progressWriter) print() { + pw.lastPrint = time.Now() + pw.logf("Downloaded %v/%v (%.1f%%)", pw.done, pw.total, float64(pw.done)/float64(pw.total)*100) +} + func parsePrivateKey(data []byte, typeTag string) (ed25519.PrivateKey, error) { b, rest := pem.Decode(data) if b == nil { diff --git a/clientupdate/distsign/distsign_test.go b/clientupdate/distsign/distsign_test.go index 16f71b2e8..c7f5f023d 100644 --- a/clientupdate/distsign/distsign_test.go +++ b/clientupdate/distsign/distsign_test.go @@ -5,6 +5,7 @@ package distsign import ( "bytes" + "context" "crypto/ed25519" "net/http" "net/http/httptest" @@ -97,7 +98,7 @@ func TestDownload(t *testing.T) { t.Cleanup(func() { os.Remove(dst) }) - err := c.Download(tt.src, dst) + err := c.Download(context.Background(), tt.src, dst) if err != nil { if tt.wantErr { return @@ -121,9 +122,10 @@ func TestDownload(t *testing.T) { func TestRotateRoot(t *testing.T) { srv := newTestServer(t) c1 := srv.client(t) + ctx := context.Background() srv.addSigned("hello", []byte("world")) - if err := c1.Download("hello", filepath.Join(t.TempDir(), "hello")); err != nil { + if err := c1.Download(ctx, "hello", filepath.Join(t.TempDir(), "hello")); err != nil { t.Fatalf("Download failed on a fresh server: %v", err) } @@ -132,13 +134,13 @@ func TestRotateRoot(t *testing.T) { // Old client can still download files because it still trusts the old // root key. - if err := c1.Download("hello", filepath.Join(t.TempDir(), "hello")); err != nil { + if err := c1.Download(ctx, "hello", filepath.Join(t.TempDir(), "hello")); err != nil { t.Fatalf("Download failed after root rotation on old client: %v", err) } // New client should fail download because current signing key is signed by // the revoked root that new client doesn't trust. c2 := srv.client(t) - if err := c2.Download("hello", filepath.Join(t.TempDir(), "hello")); err == nil { + if err := c2.Download(ctx, "hello", filepath.Join(t.TempDir(), "hello")); err == nil { t.Fatalf("Download succeeded on new client, but signing key is signed with revoked root key") } // Re-sign signing key with another valid root that client still trusts. @@ -147,10 +149,10 @@ func TestRotateRoot(t *testing.T) { // // Note: we don't need to re-sign the "hello" file because signing key // didn't change (only signing key's signature). - if err := c1.Download("hello", filepath.Join(t.TempDir(), "hello")); err != nil { + if err := c1.Download(ctx, "hello", filepath.Join(t.TempDir(), "hello")); err != nil { t.Fatalf("Download failed after root rotation on old client with re-signed signing key: %v", err) } - if err := c2.Download("hello", filepath.Join(t.TempDir(), "hello")); err != nil { + if err := c2.Download(ctx, "hello", filepath.Join(t.TempDir(), "hello")); err != nil { t.Fatalf("Download failed after root rotation on new client with re-signed signing key: %v", err) } } @@ -158,46 +160,47 @@ func TestRotateRoot(t *testing.T) { func TestRotateSigning(t *testing.T) { srv := newTestServer(t) c := srv.client(t) + ctx := context.Background() srv.addSigned("hello", []byte("world")) - if err := c.Download("hello", filepath.Join(t.TempDir(), "hello")); err != nil { + if err := c.Download(ctx, "hello", filepath.Join(t.TempDir(), "hello")); err != nil { t.Fatalf("Download failed on a fresh server: %v", err) } // Replace signing key but don't publish it yet. srv.sign = append(srv.sign, newSigningKeyPair(t)) - if err := c.Download("hello", filepath.Join(t.TempDir(), "hello")); err != nil { + if err := c.Download(ctx, "hello", filepath.Join(t.TempDir(), "hello")); err != nil { t.Fatalf("Download failed after new signing key added but before publishing it: %v", err) } // Publish new signing key bundle with both keys. srv.resignSigningKeys() - if err := c.Download("hello", filepath.Join(t.TempDir(), "hello")); err != nil { + if err := c.Download(ctx, "hello", filepath.Join(t.TempDir(), "hello")); err != nil { t.Fatalf("Download failed after new signing key was published: %v", err) } // Re-sign the "hello" file with new signing key. srv.add("hello.sig", srv.sign[1].sign([]byte("world"))) - if err := c.Download("hello", filepath.Join(t.TempDir(), "hello")); err != nil { + if err := c.Download(ctx, "hello", filepath.Join(t.TempDir(), "hello")); err != nil { t.Fatalf("Download failed after re-signing with new signing key: %v", err) } // Drop the old signing key. srv.sign = srv.sign[1:] srv.resignSigningKeys() - if err := c.Download("hello", filepath.Join(t.TempDir(), "hello")); err != nil { + if err := c.Download(ctx, "hello", filepath.Join(t.TempDir(), "hello")); err != nil { t.Fatalf("Download failed after removing old signing key: %v", err) } // Add another key and re-sign the file with it *before* publishing. srv.sign = append(srv.sign, newSigningKeyPair(t)) srv.add("hello.sig", srv.sign[1].sign([]byte("world"))) - if err := c.Download("hello", filepath.Join(t.TempDir(), "hello")); err == nil { + if err := c.Download(ctx, "hello", filepath.Join(t.TempDir(), "hello")); err == nil { t.Fatalf("Download succeeded when signed with a not-yet-published signing key") } // Fix this by publishing the new key. srv.resignSigningKeys() - if err := c.Download("hello", filepath.Join(t.TempDir(), "hello")); err != nil { + if err := c.Download(ctx, "hello", filepath.Join(t.TempDir(), "hello")); err != nil { t.Fatalf("Download failed after publishing new signing key: %v", err) } } @@ -355,6 +358,7 @@ func (s *testServer) client(t *testing.T) *Client { t.Fatal(err) } return &Client{ + logf: t.Logf, roots: roots, pkgsAddr: u, } diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index 88ab152bc..e827b81a9 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -24,7 +24,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep github.com/google/uuid from tailscale.com/util/quarantine+ github.com/gorilla/csrf from tailscale.com/client/web github.com/gorilla/securecookie from github.com/gorilla/csrf - github.com/hdevalence/ed25519consensus from tailscale.com/tka + github.com/hdevalence/ed25519consensus from tailscale.com/tka+ L github.com/josharian/native from github.com/mdlayher/netlink+ L 💣 github.com/jsimonetti/rtnetlink from tailscale.com/net/interfaces+ L github.com/jsimonetti/rtnetlink/internal/unix from github.com/jsimonetti/rtnetlink @@ -73,6 +73,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep tailscale.com/client/tailscale/apitype from tailscale.com/cmd/tailscale/cli+ tailscale.com/client/web from tailscale.com/cmd/tailscale/cli tailscale.com/clientupdate from tailscale.com/cmd/tailscale/cli + tailscale.com/clientupdate/distsign from tailscale.com/clientupdate tailscale.com/cmd/tailscale/cli from tailscale.com/cmd/tailscale tailscale.com/control/controlbase from tailscale.com/control/controlhttp tailscale.com/control/controlhttp from tailscale.com/cmd/tailscale/cli