|
|
|
@ -8,11 +8,24 @@
|
|
|
|
|
package tailssh
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"context"
|
|
|
|
|
"errors"
|
|
|
|
|
"fmt"
|
|
|
|
|
"net"
|
|
|
|
|
"os/exec"
|
|
|
|
|
"os/user"
|
|
|
|
|
"testing"
|
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
|
"github.com/gliderlabs/ssh"
|
|
|
|
|
"inet.af/netaddr"
|
|
|
|
|
"tailscale.com/ipn"
|
|
|
|
|
"tailscale.com/ipn/ipnlocal"
|
|
|
|
|
"tailscale.com/net/tsdial"
|
|
|
|
|
"tailscale.com/tailcfg"
|
|
|
|
|
"tailscale.com/tstest"
|
|
|
|
|
"tailscale.com/types/logger"
|
|
|
|
|
"tailscale.com/wgengine"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
func TestMatchRule(t *testing.T) {
|
|
|
|
@ -155,3 +168,75 @@ func TestMatchRule(t *testing.T) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func timePtr(t time.Time) *time.Time { return &t }
|
|
|
|
|
|
|
|
|
|
func TestSSH(t *testing.T) {
|
|
|
|
|
ml := new(tstest.MemLogger)
|
|
|
|
|
var logf logger.Logf = ml.Logf
|
|
|
|
|
eng, err := wgengine.NewFakeUserspaceEngine(logf, 0)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
lb, err := ipnlocal.NewLocalBackend(logf, "",
|
|
|
|
|
new(ipn.MemoryStore),
|
|
|
|
|
new(tsdial.Dialer),
|
|
|
|
|
eng, 0)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
defer lb.Shutdown()
|
|
|
|
|
dir := t.TempDir()
|
|
|
|
|
lb.SetVarRoot(dir)
|
|
|
|
|
|
|
|
|
|
srv := &server{lb, logf}
|
|
|
|
|
ss, err := srv.newSSHServer()
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
u, err := user.Current()
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ci := &sshConnInfo{
|
|
|
|
|
sshUser: "test",
|
|
|
|
|
srcIP: netaddr.MustParseIP("1.2.3.4"),
|
|
|
|
|
node: &tailcfg.Node{},
|
|
|
|
|
uprof: &tailcfg.UserProfile{},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
|
defer cancel()
|
|
|
|
|
ss.Handler = func(s ssh.Session) {
|
|
|
|
|
srv.handleAcceptedSSH(ctx, s, ci, u)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ln, err := net.Listen("tcp4", "127.0.0.1:0")
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
defer ln.Close()
|
|
|
|
|
port := ln.Addr().(*net.TCPAddr).Port
|
|
|
|
|
|
|
|
|
|
go func() {
|
|
|
|
|
for {
|
|
|
|
|
c, err := ln.Accept()
|
|
|
|
|
if err != nil {
|
|
|
|
|
if !errors.Is(err, net.ErrClosed) {
|
|
|
|
|
t.Errorf("Accept: %v", err)
|
|
|
|
|
}
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
go ss.HandleConn(c)
|
|
|
|
|
}
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
got, err := exec.Command("ssh",
|
|
|
|
|
"-p", fmt.Sprint(port),
|
|
|
|
|
"-o", "StrictHostKeyChecking=no",
|
|
|
|
|
"user@127.0.0.1", "env").CombinedOutput()
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
t.Logf("Got: %s", got)
|
|
|
|
|
}
|
|
|
|
|