@ -8,12 +8,15 @@ package tailssh
import (
import (
"bytes"
"bytes"
"context"
"context"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/rand"
"encoding/json"
"encoding/json"
"errors"
"errors"
"fmt"
"fmt"
"io"
"io"
"log"
"net"
"net"
"net/http"
"net/http"
"net/http/httptest"
"net/http/httptest"
@ -41,7 +44,7 @@ import (
"tailscale.com/sessionrecording"
"tailscale.com/sessionrecording"
"tailscale.com/tailcfg"
"tailscale.com/tailcfg"
"tailscale.com/tempfork/gliderlabs/ssh"
"tailscale.com/tempfork/gliderlabs/ssh"
ssh test "tailscale.com/tempfork/sshtest/ssh"
testssh "tailscale.com/tempfork/sshtest/ssh"
"tailscale.com/tsd"
"tailscale.com/tsd"
"tailscale.com/tstest"
"tailscale.com/tstest"
"tailscale.com/types/key"
"tailscale.com/types/key"
@ -56,8 +59,6 @@ import (
"tailscale.com/wgengine"
"tailscale.com/wgengine"
)
)
type _ = sshtest . Client // TODO(bradfitz,percy): sshtest; delete this line
func TestMatchRule ( t * testing . T ) {
func TestMatchRule ( t * testing . T ) {
someAction := new ( tailcfg . SSHAction )
someAction := new ( tailcfg . SSHAction )
tests := [ ] struct {
tests := [ ] struct {
@ -510,9 +511,9 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
defer s . Shutdown ( )
defer s . Shutdown ( )
const sshUser = "alice"
const sshUser = "alice"
cfg := & go ssh. ClientConfig {
cfg := & test ssh. ClientConfig {
User : sshUser ,
User : sshUser ,
HostKeyCallback : go ssh. InsecureIgnoreHostKey ( ) ,
HostKeyCallback : test ssh. InsecureIgnoreHostKey ( ) ,
}
}
tests := [ ] struct {
tests := [ ] struct {
@ -559,12 +560,12 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
wg . Add ( 1 )
wg . Add ( 1 )
go func ( ) {
go func ( ) {
defer wg . Done ( )
defer wg . Done ( )
c , chans , reqs , err := go ssh. NewClientConn ( sc , sc . RemoteAddr ( ) . String ( ) , cfg )
c , chans , reqs , err := test ssh. NewClientConn ( sc , sc . RemoteAddr ( ) . String ( ) , cfg )
if err != nil {
if err != nil {
t . Errorf ( "client: %v" , err )
t . Errorf ( "client: %v" , err )
return
return
}
}
client := go ssh. NewClient ( c , chans , reqs )
client := test ssh. NewClient ( c , chans , reqs )
defer client . Close ( )
defer client . Close ( )
session , err := client . NewSession ( )
session , err := client . NewSession ( )
if err != nil {
if err != nil {
@ -645,21 +646,21 @@ func TestMultipleRecorders(t *testing.T) {
sc , dc := memnet . NewTCPConn ( src , dst , 1024 )
sc , dc := memnet . NewTCPConn ( src , dst , 1024 )
const sshUser = "alice"
const sshUser = "alice"
cfg := & go ssh. ClientConfig {
cfg := & test ssh. ClientConfig {
User : sshUser ,
User : sshUser ,
HostKeyCallback : go ssh. InsecureIgnoreHostKey ( ) ,
HostKeyCallback : test ssh. InsecureIgnoreHostKey ( ) ,
}
}
var wg sync . WaitGroup
var wg sync . WaitGroup
wg . Add ( 1 )
wg . Add ( 1 )
go func ( ) {
go func ( ) {
defer wg . Done ( )
defer wg . Done ( )
c , chans , reqs , err := go ssh. NewClientConn ( sc , sc . RemoteAddr ( ) . String ( ) , cfg )
c , chans , reqs , err := test ssh. NewClientConn ( sc , sc . RemoteAddr ( ) . String ( ) , cfg )
if err != nil {
if err != nil {
t . Errorf ( "client: %v" , err )
t . Errorf ( "client: %v" , err )
return
return
}
}
client := go ssh. NewClient ( c , chans , reqs )
client := test ssh. NewClient ( c , chans , reqs )
defer client . Close ( )
defer client . Close ( )
session , err := client . NewSession ( )
session , err := client . NewSession ( )
if err != nil {
if err != nil {
@ -736,21 +737,21 @@ func TestSSHRecordingNonInteractive(t *testing.T) {
sc , dc := memnet . NewTCPConn ( src , dst , 1024 )
sc , dc := memnet . NewTCPConn ( src , dst , 1024 )
const sshUser = "alice"
const sshUser = "alice"
cfg := & go ssh. ClientConfig {
cfg := & test ssh. ClientConfig {
User : sshUser ,
User : sshUser ,
HostKeyCallback : go ssh. InsecureIgnoreHostKey ( ) ,
HostKeyCallback : test ssh. InsecureIgnoreHostKey ( ) ,
}
}
var wg sync . WaitGroup
var wg sync . WaitGroup
wg . Add ( 1 )
wg . Add ( 1 )
go func ( ) {
go func ( ) {
defer wg . Done ( )
defer wg . Done ( )
c , chans , reqs , err := go ssh. NewClientConn ( sc , sc . RemoteAddr ( ) . String ( ) , cfg )
c , chans , reqs , err := test ssh. NewClientConn ( sc , sc . RemoteAddr ( ) . String ( ) , cfg )
if err != nil {
if err != nil {
t . Errorf ( "client: %v" , err )
t . Errorf ( "client: %v" , err )
return
return
}
}
client := go ssh. NewClient ( c , chans , reqs )
client := test ssh. NewClient ( c , chans , reqs )
defer client . Close ( )
defer client . Close ( )
session , err := client . NewSession ( )
session , err := client . NewSession ( )
if err != nil {
if err != nil {
@ -886,49 +887,81 @@ func TestSSHAuthFlow(t *testing.T) {
} ,
} ,
}
}
s := & server {
s := & server {
logf : log ger. Discard ,
logf : log . Printf ,
}
}
defer s . Shutdown ( )
defer s . Shutdown ( )
src , dst := must . Get ( netip . ParseAddrPort ( "100.100.100.101:2231" ) ) , must . Get ( netip . ParseAddrPort ( "100.100.100.102:22" ) )
src , dst := must . Get ( netip . ParseAddrPort ( "100.100.100.101:2231" ) ) , must . Get ( netip . ParseAddrPort ( "100.100.100.102:22" ) )
for _ , tc := range tests {
for _ , tc := range tests {
t . Run ( tc . name , func ( t * testing . T ) {
for _ , authMethods := range [ ] [ ] string { nil , { "publickey" , "password" } , { "password" , "publickey" } } {
t . Run ( fmt . Sprintf ( "%s-skip-none-auth-%v" , tc . name , strings . Join ( authMethods , "-then-" ) ) , func ( t * testing . T ) {
sc , dc := memnet . NewTCPConn ( src , dst , 1024 )
sc , dc := memnet . NewTCPConn ( src , dst , 1024 )
s . lb = tc . state
s . lb = tc . state
sshUser := "alice"
sshUser := "alice"
if tc . sshUser != "" {
if tc . sshUser != "" {
sshUser = tc . sshUser
sshUser = tc . sshUser
}
}
wantBanners := slices . Clone ( tc . wantBanners )
noneAuthEnabled := len ( authMethods ) == 0
var publicKeyUsed atomic . Bool
var passwordUsed atomic . Bool
var passwordUsed atomic . Bool
cfg := & gossh . ClientConfig {
var methods [ ] testssh . AuthMethod
User : sshUser ,
HostKeyCallback : gossh . InsecureIgnoreHostKey ( ) ,
for _ , authMethod := range authMethods {
Auth : [ ] gossh . AuthMethod {
switch authMethod {
gossh . PasswordCallback ( func ( ) ( secret string , err error ) {
case "publickey" :
if ! tc . usesPassword {
methods = append ( methods ,
t . Error ( "unexpected use of PasswordCallback" )
testssh . PublicKeysCallback ( func ( ) ( signers [ ] testssh . Signer , err error ) {
return "" , errors . New ( "unexpected use of PasswordCallback" )
publicKeyUsed . Store ( true )
key , err := ecdsa . GenerateKey ( elliptic . P384 ( ) , rand . Reader )
if err != nil {
return nil , err
}
sig , err := testssh . NewSignerFromKey ( key )
if err != nil {
return nil , err
}
}
return [ ] testssh . Signer { sig } , nil
} ) )
case "password" :
methods = append ( methods , testssh . PasswordCallback ( func ( ) ( secret string , err error ) {
passwordUsed . Store ( true )
passwordUsed . Store ( true )
return "any-pass" , nil
return "any-pass" , nil
} ) ,
} ) )
} ,
}
}
if noneAuthEnabled && tc . usesPassword {
methods = append ( methods , testssh . PasswordCallback ( func ( ) ( secret string , err error ) {
passwordUsed . Store ( true )
return "any-pass" , nil
} ) )
}
cfg := & testssh . ClientConfig {
User : sshUser ,
HostKeyCallback : testssh . InsecureIgnoreHostKey ( ) ,
SkipNoneAuth : ! noneAuthEnabled ,
Auth : methods ,
BannerCallback : func ( message string ) error {
BannerCallback : func ( message string ) error {
if len ( tc . wantBanners ) == 0 {
if len ( wantBanners ) == 0 {
t . Errorf ( "unexpected banner: %q" , message )
t . Errorf ( "unexpected banner: %q" , message )
} else if message != tc . wantBanners [ 0 ] {
} else if message != wantBanners [ 0 ] {
t . Errorf ( "banner = %q; want %q" , message , tc . wantBanners [ 0 ] )
t . Errorf ( "banner = %q; want %q" , message , wantBanners [ 0 ] )
} else {
} else {
t . Logf ( "banner = %q" , message )
t . Logf ( "banner = %q" , message )
tc . wantBanners = tc . wantBanners [ 1 : ]
wantBanners = wantBanners [ 1 : ]
}
}
return nil
return nil
} ,
} ,
}
}
var wg sync . WaitGroup
var wg sync . WaitGroup
wg . Add ( 1 )
wg . Add ( 1 )
go func ( ) {
go func ( ) {
defer wg . Done ( )
defer wg . Done ( )
c , chans , reqs , err := go ssh. NewClientConn ( sc , sc . RemoteAddr ( ) . String ( ) , cfg )
c , chans , reqs , err := test ssh. NewClientConn ( sc , sc . RemoteAddr ( ) . String ( ) , cfg )
if err != nil {
if err != nil {
if ! tc . authErr {
if ! tc . authErr {
t . Errorf ( "client: %v" , err )
t . Errorf ( "client: %v" , err )
@ -939,7 +972,7 @@ func TestSSHAuthFlow(t *testing.T) {
t . Errorf ( "client: expected error, got nil" )
t . Errorf ( "client: expected error, got nil" )
return
return
}
}
client := go ssh. NewClient ( c , chans , reqs )
client := test ssh. NewClient ( c , chans , reqs )
defer client . Close ( )
defer client . Close ( )
session , err := client . NewSession ( )
session , err := client . NewSession ( )
if err != nil {
if err != nil {
@ -956,11 +989,50 @@ func TestSSHAuthFlow(t *testing.T) {
t . Errorf ( "unexpected error: %v" , err )
t . Errorf ( "unexpected error: %v" , err )
}
}
wg . Wait ( )
wg . Wait ( )
if len ( tc . wantBanners ) > 0 {
if len ( wantBanners ) > 0 {
t . Errorf ( "missing banners: %v" , tc . wantBanners )
t . Errorf ( "missing banners: %v" , wantBanners )
}
// Check to see which callbacks were invoked.
//
// When `none` auth is enabled, the public key callback should
// never fire, and the password callback should only fire if
// authentication succeeded and the client was trying to force
// password authentication by connecting with the '-password'
// username suffix.
//
// When skipping `none` auth, the first callback should always
// fire, and the 2nd callback should fire only if
// authentication failed.
wantPublicKey := false
wantPassword := false
if noneAuthEnabled {
wantPassword = ! tc . authErr && tc . usesPassword
} else {
for i , authMethod := range authMethods {
switch authMethod {
case "publickey" :
wantPublicKey = i == 0 || tc . authErr
case "password" :
wantPassword = i == 0 || tc . authErr
}
}
}
if wantPublicKey && ! publicKeyUsed . Load ( ) {
t . Error ( "public key should have been attempted" )
} else if ! wantPublicKey && publicKeyUsed . Load ( ) {
t . Errorf ( "public key should not have been attempted" )
}
if wantPassword && ! passwordUsed . Load ( ) {
t . Error ( "password should have been attempted" )
} else if ! wantPassword && passwordUsed . Load ( ) {
t . Error ( "password should not have been attempted" )
}
}
} )
} )
}
}
}
}
}
func TestSSH ( t * testing . T ) {
func TestSSH ( t * testing . T ) {