cmd/tailscale/cli: continue fleshing out serve CLI tests

The serve CLI doesn't exist yet, but we want nice tests for it when it
does exist.

Updates tailscale/corp#7515

Change-Id: Ib4c73d606242c4228f87410bbfd29bec52ca6c60
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/6274/head
Brad Fitzpatrick 2 years ago committed by Brad Fitzpatrick
parent b5ac9172fd
commit 0544d6ed04

@ -184,6 +184,7 @@ change in the future.
} }
if envknob.UseWIPCode() { if envknob.UseWIPCode() {
rootCmd.Subcommands = append(rootCmd.Subcommands, idTokenCmd) rootCmd.Subcommands = append(rootCmd.Subcommands, idTokenCmd)
rootCmd.Subcommands = append(rootCmd.Subcommands, serveCmd)
} }
// Don't advertise the debug command, but it exists. // Don't advertise the debug command, but it exists.

@ -4,11 +4,152 @@
package cli package cli
import "tailscale.com/ipn" import (
"context"
"encoding/json"
"flag"
"io"
"os"
func applyServeMutation(current *ipn.ServeConfig, command []string) (*ipn.ServeConfig, error) { "github.com/peterbourgon/ff/v3/ffcli"
if len(command) == 0 { "tailscale.com/ipn"
return current, nil "tailscale.com/util/mak"
)
var serveCmd = newServeCommand(&serveEnv{})
// newServeCommand returns a new "serve" subcommand using e as its environmment.
func newServeCommand(e *serveEnv) *ffcli.Command {
return &ffcli.Command{
Name: "serve",
ShortHelp: "TODO",
ShortUsage: "serve {show-config|https|tcp|ingress} <args>",
LongHelp: "", // TODO
Exec: e.runServe,
FlagSet: e.newFlags("serve", func(fs *flag.FlagSet) {}),
Subcommands: []*ffcli.Command{
{
Name: "show-config",
Exec: e.runServeShowConfig,
ShortHelp: "show current serve config",
},
{
Name: "tcp",
Exec: e.runServeTCP,
ShortHelp: "add or remove a TCP port forward",
FlagSet: e.newFlags("serve-tcp", func(fs *flag.FlagSet) {
fs.BoolVar(&e.terminateTLS, "terminate-tls", false, "terminate TLS before forwarding TCP connection")
}),
},
{
Name: "ingress",
Exec: e.runServeIngress,
ShortHelp: "enable or disable ingress",
FlagSet: e.newFlags("serve-ingress", func(fs *flag.FlagSet) {}),
},
},
}
}
// serveEnv is the environment the serve command runs within. All I/O should be
// done via serveEnv methods so that it can be faked out for tests.
//
// It also contains the flags, as registered with newServeCommand.
type serveEnv struct {
// flags
terminateTLS bool
// optional stuff for tests:
testFlagOut io.Writer
testGetServeConfig func(context.Context) (*ipn.ServeConfig, error)
testSetServeConfig func(context.Context, *ipn.ServeConfig) error
testStdout io.Writer
}
func (e *serveEnv) newFlags(name string, setup func(fs *flag.FlagSet)) *flag.FlagSet {
onError, out := flag.ExitOnError, Stderr
if e.testFlagOut != nil {
onError, out = flag.ContinueOnError, e.testFlagOut
}
fs := flag.NewFlagSet(name, onError)
fs.SetOutput(out)
if setup != nil {
setup(fs)
}
return fs
}
func (e *serveEnv) getServeConfig(ctx context.Context) (*ipn.ServeConfig, error) {
if e.testGetServeConfig != nil {
return e.testGetServeConfig(ctx)
}
return localClient.GetServeConfig(ctx)
}
func (e *serveEnv) setServeConfig(ctx context.Context, c *ipn.ServeConfig) error {
if e.testSetServeConfig != nil {
return e.testSetServeConfig(ctx, c)
}
return localClient.SetServeConfig(ctx, c)
}
func (e *serveEnv) stdout() io.Writer {
if e.testStdout != nil {
return e.testStdout
} }
return os.Stdout
}
func (e *serveEnv) runServe(ctx context.Context, args []string) error {
panic("TODO")
}
func (e *serveEnv) runServeShowConfig(ctx context.Context, args []string) error {
sc, err := e.getServeConfig(ctx)
if err != nil {
return err
}
j, err := json.MarshalIndent(sc, "", " ")
if err != nil {
return err
}
j = append(j, '\n')
e.stdout().Write(j)
return nil
}
func (e *serveEnv) runServeTCP(ctx context.Context, args []string) error {
panic("TODO") panic("TODO")
} }
func (e *serveEnv) runServeIngress(ctx context.Context, args []string) error {
if len(args) != 1 {
return flag.ErrHelp
}
var on bool
switch args[0] {
case "on", "off":
on = args[0] == "on"
default:
return flag.ErrHelp
}
sc, err := e.getServeConfig(ctx)
if err != nil {
return err
}
var key ipn.HostPort = "foo:123" // TODO(bradfitz,shayne): fix
if on && sc != nil && sc.AllowIngress[key] ||
!on && (sc == nil || !sc.AllowIngress[key]) {
// Nothing to do.
return nil
}
if sc == nil {
sc = &ipn.ServeConfig{}
}
if on {
mak.Set(&sc.AllowIngress, "foo:123", true)
} else {
delete(sc.AllowIngress, "foo:123")
}
return e.setServeConfig(ctx, sc)
}

@ -5,8 +5,13 @@
package cli package cli
import ( import (
"bytes"
"context"
"flag"
"fmt"
"reflect" "reflect"
"runtime" "runtime"
"strings"
"testing" "testing"
"tailscale.com/ipn" "tailscale.com/ipn"
@ -15,39 +20,107 @@ import (
func TestServeConfigMutations(t *testing.T) { func TestServeConfigMutations(t *testing.T) {
// Stateful mutations, starting from an empty config. // Stateful mutations, starting from an empty config.
type step struct { type step struct {
command []string // serve args command []string // serve args; nil means no command to run (only reset)
reset bool // if true, reset all ServeConfig state reset bool // if true, reset all ServeConfig state
want *ipn.ServeConfig want *ipn.ServeConfig // non-nil means we want a save of this value
wantErr string wantErr func(error) (badErrMsg string) // nil means no error is wanted
line int // line number of addStep call, for error messages line int // line number of addStep call, for error messages
} }
var steps []step var steps []step
add := func(s step) { add := func(s step) {
_, _, s.line, _ = runtime.Caller(1) _, _, s.line, _ = runtime.Caller(1)
steps = append(steps, s) steps = append(steps, s)
} }
add(step{reset: true}) add(step{reset: true})
add(step{ add(step{
want: nil, command: cmd("ingress on"),
want: &ipn.ServeConfig{AllowIngress: map[ipn.HostPort]bool{"foo:123": true}},
})
add(step{
command: cmd("ingress on"),
want: nil, // nothing to save
}) })
add(step{
command: cmd("ingress off"),
want: &ipn.ServeConfig{AllowIngress: map[ipn.HostPort]bool{}},
})
add(step{
command: cmd("ingress off"),
want: nil, // nothing to save
})
add(step{
command: cmd("ingress"),
wantErr: exactErr(flag.ErrHelp, "flag.ErrHelp"),
})
// And now run the steps above.
var current *ipn.ServeConfig var current *ipn.ServeConfig
for i, st := range steps { for i, st := range steps {
t.Logf("Executing step #%d (line %v) ... ", i, st.line)
if st.reset { if st.reset {
t.Logf("(resetting state)") t.Logf("Executing step #%d, line %v: [reset]", i, st.line)
current = nil current = nil
} }
newState, err := applyServeMutation(current, st.command) if st.command == nil {
var gotErr string continue
}
t.Logf("Executing step #%d, line %v: %q ... ", i, st.line, st.command)
var stdout bytes.Buffer
var flagOut bytes.Buffer
var newState *ipn.ServeConfig
e := &serveEnv{
testFlagOut: &flagOut,
testStdout: &stdout,
testGetServeConfig: func(context.Context) (*ipn.ServeConfig, error) {
return current, nil
},
testSetServeConfig: func(_ context.Context, c *ipn.ServeConfig) error {
newState = c
return nil
},
}
cmd := newServeCommand(e)
err := cmd.ParseAndRun(context.Background(), st.command)
if flagOut.Len() > 0 {
t.Logf("flag package output: %q", flagOut.Bytes())
}
if err != nil { if err != nil {
gotErr = err.Error() if st.wantErr == nil {
t.Fatalf("step #%d, line %v: unexpected error: %v", i, st.line, err)
}
if bad := st.wantErr(err); bad != "" {
t.Fatalf("step #%d, line %v: unexpected error: %v", i, st.line, bad)
}
continue
} }
if gotErr != st.wantErr { if st.wantErr != nil {
t.Fatalf("[%d] %v: got error %q, want %q", i, st.command, gotErr, st.wantErr) t.Fatalf("step #%d, line %v: got success (saved=%v), but wanted an error", i, st.line, newState != nil)
} }
if !reflect.DeepEqual(newState, st.want) { if !reflect.DeepEqual(newState, st.want) {
t.Fatalf("[%d] %v: bad state. got:\n%s\n\nwant:\n%s\n", t.Fatalf("[%d] %v: bad state. got:\n%s\n\nwant:\n%s\n",
i, st.command, asJSON(newState), asJSON(st.want)) i, st.command, asJSON(newState), asJSON(st.want))
} }
if newState != nil {
current = newState
}
} }
} }
// exactError returns an error checker that wants exactly the provided want error.
// If optName is non-empty, it's used in the error message.
func exactErr(want error, optName ...string) func(error) string {
return func(got error) string {
if got == want {
return ""
}
if len(optName) > 0 {
return fmt.Sprintf("got error %v, want %v", got, optName[0])
}
return fmt.Sprintf("got error %v, want %v", got, want)
}
}
func cmd(s string) []string {
return strings.Fields(s)
}

Loading…
Cancel
Save