@ -8,14 +8,13 @@ package tailssh
import (
"bufio"
"bytes"
"context"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"io"
"log"
@ -24,6 +23,7 @@ import (
"net/netip"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"testing"
@ -34,8 +34,10 @@ import (
"github.com/pkg/sftp"
gossh "github.com/tailscale/golang-x-crypto/ssh"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"tailscale.com/net/tsdial"
"tailscale.com/tailcfg"
glider "tailscale.com/tempfork/gliderlabs/ssh"
"tailscale.com/types/key"
"tailscale.com/types/netmap"
"tailscale.com/util/set"
@ -300,6 +302,95 @@ func TestIntegrationSCP(t *testing.T) {
}
}
func TestSSHAgentForwarding ( t * testing . T ) {
debugTest . Store ( true )
t . Cleanup ( func ( ) {
debugTest . Store ( false )
} )
// Create a client SSH key
tmpDir , err := os . MkdirTemp ( "" , "" )
if err != nil {
t . Fatal ( err )
}
t . Cleanup ( func ( ) {
_ = os . RemoveAll ( tmpDir )
} )
pkFile := filepath . Join ( tmpDir , "pk" )
clientKey , clientKeyRSA := generateClientKey ( t , pkFile )
// Start upstream SSH server
l , err := net . Listen ( "tcp" , "127.0.0.1:" )
if err != nil {
t . Fatalf ( "unable to listen for SSH: %s" , err )
}
t . Cleanup ( func ( ) {
_ = l . Close ( )
} )
// Run an SSH server that accepts connections from that client SSH key.
gs := glider . Server {
Handler : func ( s glider . Session ) {
io . WriteString ( s , "Hello world\n" )
} ,
PublicKeyHandler : func ( ctx glider . Context , key glider . PublicKey ) error {
// Note - this is not meant to be cryptographically secure, it's
// just checking that SSH agent forwarding is forwarding the right
// key.
a := key . Marshal ( )
b := clientKey . PublicKey ( ) . Marshal ( )
if ! bytes . Equal ( a , b ) {
return errors . New ( "key mismatch" )
}
return nil
} ,
}
go gs . Serve ( l )
// Run tailscale SSH server and connect to it
username := "testuser"
tailscaleAddr := testServer ( t , username , false ) // TODO: make this false to use V2 behavior
tcl , err := ssh . Dial ( "tcp" , tailscaleAddr , & ssh . ClientConfig {
HostKeyCallback : ssh . InsecureIgnoreHostKey ( ) ,
} )
if err != nil {
t . Fatal ( err )
}
t . Cleanup ( func ( ) { tcl . Close ( ) } )
s , err := tcl . NewSession ( )
if err != nil {
t . Fatal ( err )
}
// Set up SSH agent forwarding on the client
err = agent . RequestAgentForwarding ( s )
if err != nil {
t . Fatal ( err )
}
keyring := agent . NewKeyring ( )
keyring . Add ( agent . AddedKey {
PrivateKey : clientKeyRSA ,
} )
err = agent . ForwardToAgent ( tcl , keyring )
if err != nil {
t . Fatal ( err )
}
// Attempt to SSH to the upstream test server using the forwarded SSH key
// and run the "true" command.
upstreamHost , upstreamPort , err := net . SplitHostPort ( l . Addr ( ) . String ( ) )
if err != nil {
t . Fatal ( err )
}
o , err := s . CombinedOutput ( fmt . Sprintf ( ` ssh -T -o StrictHostKeyChecking=no -p %s upstreamuser@%s "true" ` , upstreamPort , upstreamHost ) )
if err != nil {
t . Fatalf ( "unable to call true command: %s\n%s" , err , o )
}
}
func fallbackToSUAvailable ( ) bool {
if runtime . GOOS != "linux" {
return false
@ -374,10 +465,25 @@ readLoop:
return string ( _got )
}
func testClient ( t * testing . T , forceV1Behavior bool ) * ssh . Client {
func testClient ( t * testing . T , forceV1Behavior bool , authMethods ... ssh . AuthMethod ) * ssh . Client {
t . Helper ( )
username := "testuser"
addr := testServer ( t , username , forceV1Behavior )
cl , err := ssh . Dial ( "tcp" , addr , & ssh . ClientConfig {
HostKeyCallback : ssh . InsecureIgnoreHostKey ( ) ,
Auth : authMethods ,
} )
if err != nil {
t . Fatal ( err )
}
t . Cleanup ( func ( ) { cl . Close ( ) } )
return cl
}
func testServer ( t * testing . T , username string , forceV1Behavior bool ) string {
srv := & server {
lb : & testBackend { localUser : username , forceV1Behavior : forceV1Behavior } ,
logf : log . Printf ,
@ -392,21 +498,15 @@ func testClient(t *testing.T, forceV1Behavior bool) *ssh.Client {
t . Cleanup ( func ( ) { l . Close ( ) } )
go func ( ) {
for {
conn , err := l . Accept ( )
if err == nil {
go srv . HandleSSHConn ( & addressFakingConn { conn } )
}
} ( )
cl , err := ssh . Dial ( "tcp" , l . Addr ( ) . String ( ) , & ssh . ClientConfig {
HostKeyCallback : ssh . InsecureIgnoreHostKey ( ) ,
} )
if err != nil {
log . Fatal ( err )
}
t . Cleanup ( func ( ) { cl . Close ( ) } )
} ( )
return c l
return l . Addr ( ) . String ( )
}
func testSession ( t * testing . T , forceV1Behavior bool ) * session {
@ -417,7 +517,7 @@ func testSession(t *testing.T, forceV1Behavior bool) *session {
func testSessionFor ( t * testing . T , cl * ssh . Client ) * session {
s , err := cl . NewSession ( )
if err != nil {
log . Fatal ( err )
t . Fatal ( err )
}
t . Cleanup ( func ( ) { s . Close ( ) } )
@ -435,6 +535,31 @@ func testSessionFor(t *testing.T, cl *ssh.Client) *session {
}
}
func generateClientKey ( t * testing . T , privateKeyFile string ) ( ssh . Signer , * rsa . PrivateKey ) {
t . Helper ( )
priv , err := rsa . GenerateKey ( rand . Reader , 2048 )
if err != nil {
t . Fatal ( err )
}
mk , err := x509 . MarshalPKCS8PrivateKey ( priv )
if err != nil {
t . Fatal ( err )
}
privateKey := pem . EncodeToMemory ( & pem . Block { Type : "PRIVATE KEY" , Bytes : mk } )
if privateKey == nil {
t . Fatal ( "failed to encoded private key" )
}
err = os . WriteFile ( privateKeyFile , privateKey , 0600 )
if err != nil {
t . Fatal ( err )
}
signer , err := ssh . ParsePrivateKey ( privateKey )
if err != nil {
t . Fatal ( err )
}
return signer , priv
}
// testBackend implements ipnLocalBackend
type testBackend struct {
localUser string
@ -443,19 +568,10 @@ type testBackend struct {
func ( tb * testBackend ) GetSSH_HostKeys ( ) ( [ ] gossh . Signer , error ) {
var result [ ] gossh . Signer
for _ , typ := range [ ] string { "ed25519" , "ecdsa" , "rsa" } {
var priv any
var err error
switch typ {
case "ed25519" :
_ , priv , err = ed25519 . GenerateKey ( rand . Reader )
case "ecdsa" :
curve := elliptic . P256 ( )
priv , err = ecdsa . GenerateKey ( curve , rand . Reader )
case "rsa" :
const keySize = 2048
priv , err = rsa . GenerateKey ( rand . Reader , keySize )
}
if err != nil {
return nil , err
}
@ -469,7 +585,6 @@ func (tb *testBackend) GetSSH_HostKeys() ([]gossh.Signer, error) {
return nil , err
}
result = append ( result , signer )
}
return result , nil
}
@ -487,7 +602,7 @@ func (tb *testBackend) NetMap() *netmap.NetworkMap {
Rules : [ ] * tailcfg . SSHRule {
{
Principals : [ ] * tailcfg . SSHPrincipal { { Any : true } } ,
Action : & tailcfg . SSHAction { Accept : true },
Action : & tailcfg . SSHAction { Accept : true , AllowAgentForwarding : true },
SSHUsers : map [ string ] string { "*" : tb . localUser } ,
} ,
} ,