@ -38,6 +38,7 @@
package distsign
import (
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/binary"
@ -46,12 +47,17 @@ import (
"fmt"
"hash"
"io"
"log"
"net/http"
"net/url"
"os"
"time"
"github.com/hdevalence/ed25519consensus"
"golang.org/x/crypto/blake2s"
"tailscale.com/net/tshttpproxy"
"tailscale.com/types/logger"
"tailscale.com/util/must"
)
const (
@ -177,18 +183,22 @@ func (ph *PackageHash) Len() int64 { return ph.len }
// Client downloads and validates files from a distribution server.
type Client struct {
logf logger . Logf
roots [ ] ed25519 . PublicKey
pkgsAddr * url . URL
}
// NewClient returns a new client for distribution server located at pkgsAddr,
// and uses embedded root keys from the roots/ subdirectory of this package.
func NewClient ( pkgsAddr string ) ( * Client , error ) {
func NewClient ( logf logger . Logf , pkgsAddr string ) ( * Client , error ) {
if logf == nil {
logf = log . Printf
}
u , err := url . Parse ( pkgsAddr )
if err != nil {
return nil , fmt . Errorf ( "invalid pkgsAddr %q: %w" , pkgsAddr , err )
}
return & Client { roots: roots ( ) , pkgsAddr : u } , nil
return & Client { logf: logf , roots: roots ( ) , pkgsAddr : u } , nil
}
func ( c * Client ) url ( path string ) string {
@ -199,7 +209,7 @@ func (c *Client) url(path string) string {
// The file is downloaded to dstPath and its signature is validated using the
// embedded root keys. Download returns an error if anything goes wrong with
// the actual file download or with signature validation.
func ( c * Client ) Download ( srcPath, dstPath string ) error {
func ( c * Client ) Download ( ctx context . Context , srcPath, dstPath string ) error {
// Always fetch a fresh signing key.
sigPub , err := c . signingKeys ( )
if err != nil {
@ -209,11 +219,13 @@ func (c *Client) Download(srcPath, dstPath string) error {
srcURL := c . url ( srcPath )
sigURL := srcURL + ".sig"
c . logf ( "Downloading %q" , srcURL )
dstPathUnverified := dstPath + ".unverified"
hash , len , err := download ( srcURL , dstPathUnverified , downloadSizeLimit )
hash , len , err := c . download ( ctx , srcURL , dstPathUnverified , downloadSizeLimit )
if err != nil {
return err
}
c . logf ( "Downloading %q" , sigURL )
sig , err := fetch ( sigURL , signatureSizeLimit )
if err != nil {
// Best-effort clean up of downloaded package.
@ -226,6 +238,7 @@ func (c *Client) Download(srcPath, dstPath string) error {
os . Remove ( dstPathUnverified )
return fmt . Errorf ( "signature %q for file %q does not validate with the current release signing key; either you are under attack, or attempting to download an old version of Tailscale which was signed with an older signing key" , sigURL , srcURL )
}
c . logf ( "Signature OK" )
if err := os . Rename ( dstPathUnverified , dstPath ) ; err != nil {
return fmt . Errorf ( "failed to move %q to %q after signature validation" , dstPathUnverified , dstPath )
@ -272,32 +285,84 @@ func fetch(url string, limit int64) ([]byte, error) {
// download writes the response body of url into a local file at dst, up to
// limit bytes. On success, the returned value is a BLAKE2s hash of the file.
func download ( url , dst string , limit int64 ) ( [ ] byte , int64 , error ) {
resp , err := http . Get ( url )
func ( c * Client ) download ( ctx context . Context , url , dst string , limit int64 ) ( [ ] byte , int64 , error ) {
tr := http . DefaultTransport . ( * http . Transport ) . Clone ( )
tr . Proxy = tshttpproxy . ProxyFromEnvironment
defer tr . CloseIdleConnections ( )
hc := & http . Client { Transport : tr }
quickCtx , cancel := context . WithTimeout ( ctx , 30 * time . Second )
defer cancel ( )
headReq := must . Get ( http . NewRequestWithContext ( quickCtx , http . MethodHead , url , nil ) )
res , err := hc . Do ( headReq )
if err != nil {
return nil , 0 , err
}
defer resp . Body . Close ( )
h := NewPackageHash ( )
r := io . TeeReader ( io . LimitReader ( resp . Body , limit ) , h )
if res . StatusCode != http . StatusOK {
return nil , 0 , fmt . Errorf ( "HEAD %q: %v" , url , res . Status )
}
if res . ContentLength <= 0 {
return nil , 0 , fmt . Errorf ( "HEAD %q: unexpected Content-Length %v" , url , res . ContentLength )
}
c . logf ( "Download size: %v" , res . ContentLength )
f , err := os . Create ( dst )
dlReq := must . Get ( http . NewRequestWithContext ( ctx , http . MethodGet , url , nil ) )
dlRes , err := hc . Do ( dlReq )
if err != nil {
return nil , 0 , err
}
defer f . Close ( )
defer dlRes . Body . Close ( )
// TODO(bradfitz): resume from existing partial file on disk
if dlRes . StatusCode != http . StatusOK {
return nil , 0 , fmt . Errorf ( "GET %q: %v" , url , dlRes . Status )
}
if _ , err := io . Copy ( f , r ) ; err != nil {
of , err := os . Create ( dst )
if err != nil {
return nil , 0 , err
}
if err := f . Close ( ) ; err != nil {
return nil , 0 , err
defer of . Close ( )
pw := & progressWriter { total : res . ContentLength , logf : c . logf }
h := NewPackageHash ( )
n , err := io . Copy ( io . MultiWriter ( of , h , pw ) , io . LimitReader ( dlRes . Body , limit ) )
if err != nil {
return nil , n , err
}
if n != res . ContentLength {
return nil , n , fmt . Errorf ( "GET %q: downloaded %v, want %v" , url , n , res . ContentLength )
}
if err := dlRes . Body . Close ( ) ; err != nil {
return nil , n , err
}
if err := of . Close ( ) ; err != nil {
return nil , n , err
}
pw . print ( )
return h . Sum ( nil ) , h . Len ( ) , nil
}
type progressWriter struct {
done int64
total int64
lastPrint time . Time
logf logger . Logf
}
func ( pw * progressWriter ) Write ( p [ ] byte ) ( n int , err error ) {
pw . done += int64 ( len ( p ) )
if time . Since ( pw . lastPrint ) > 2 * time . Second {
pw . print ( )
}
return len ( p ) , nil
}
func ( pw * progressWriter ) print ( ) {
pw . lastPrint = time . Now ( )
pw . logf ( "Downloaded %v/%v (%.1f%%)" , pw . done , pw . total , float64 ( pw . done ) / float64 ( pw . total ) * 100 )
}
func parsePrivateKey ( data [ ] byte , typeTag string ) ( ed25519 . PrivateKey , error ) {
b , rest := pem . Decode ( data )
if b == nil {