You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
tailscale/ipn/ipnlocal/serve_test.go

590 lines
15 KiB
Go

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package ipnlocal
import (
"bytes"
"context"
"crypto/sha256"
"crypto/tls"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"net/netip"
"net/url"
"os"
"path/filepath"
"strings"
"testing"
"time"
"tailscale.com/ipn"
"tailscale.com/ipn/store/mem"
"tailscale.com/tailcfg"
"tailscale.com/tsd"
"tailscale.com/types/logid"
"tailscale.com/types/netmap"
"tailscale.com/util/cmpx"
"tailscale.com/util/mak"
"tailscale.com/util/must"
"tailscale.com/wgengine"
)
// TestExpandProxyArg feeds a table of proxy arguments to expandProxyArg and
// checks the returned target URL and insecure flag for each one.
func TestExpandProxyArg(t *testing.T) {
	// res bundles both return values of expandProxyArg so that a case can be
	// compared with a single equality check.
	type res struct {
		target   string
		insecure bool
	}
	type testCase struct {
		in   string
		want res
	}
	cases := []testCase{
		{in: "", want: res{}},
		{in: "3030", want: res{target: "http://127.0.0.1:3030"}},
		{in: "localhost:3030", want: res{target: "http://localhost:3030"}},
		{in: "10.2.3.5:3030", want: res{target: "http://10.2.3.5:3030"}},
		{in: "http://foo.com", want: res{target: "http://foo.com"}},
		{in: "https://foo.com", want: res{target: "https://foo.com"}},
		{in: "https+insecure://10.2.3.4", want: res{target: "https://10.2.3.4", insecure: true}},
	}
	for i := range cases {
		tc := &cases[i]
		var got res
		got.target, got.insecure = expandProxyArg(tc.in)
		if got == tc.want {
			continue
		}
		t.Errorf("expandProxyArg(%q) = %v, want %v", tc.in, got, tc.want)
	}
}
// TestGetServeHandler checks which mount point getServeHandler reports for a
// request path, using a web config whose handlers overlap ("/", "/bar",
// "/foo/", "/foo/bar" and "/foo/bar/").
func TestGetServeHandler(t *testing.T) {
	const serverName = "example.ts.net"

	mounts := map[string]*ipn.HTTPHandler{
		"/":         {},
		"/bar":      {},
		"/foo/":     {},
		"/foo/bar":  {},
		"/foo/bar/": {},
	}
	conf1 := &ipn.ServeConfig{
		Web: map[ipn.HostPort]*ipn.WebServerConfig{
			serverName + ":443": {Handlers: mounts},
		},
	}

	type testCase struct {
		name string
		port uint16 // destination port; 443 is used when left at zero
		path string // http.Request.URL.Path
		conf *ipn.ServeConfig
		want string // mountPoint
	}
	cases := []testCase{
		{name: "nothing", path: "/", conf: nil, want: ""},
		{name: "root", conf: conf1, path: "/", want: "/"},
		{name: "root-other", conf: conf1, path: "/other", want: "/"},
		{name: "bar", conf: conf1, path: "/bar", want: "/bar"},
		{name: "foo-bar", conf: conf1, path: "/foo/bar", want: "/foo/bar"},
		{name: "foo-bar-slash", conf: conf1, path: "/foo/bar/", want: "/foo/bar/"},
		{name: "foo-bar-other", conf: conf1, path: "/foo/bar/other", want: "/foo/bar/"},
		{name: "foo-other", conf: conf1, path: "/foo/other", want: "/foo/"},
		{name: "foo-no-trailing-slash", conf: conf1, path: "/foo", want: "/foo/"},
		{name: "dot-dots", conf: conf1, path: "/foo/../../../../../../../../etc/passwd", want: "/"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			backend := &LocalBackend{
				serveConfig: tc.conf.View(),
				logf:        t.Logf,
			}

			// Build a TLS request for serverName and attach the serve
			// context that carries the destination port.
			base := &http.Request{
				URL: &url.URL{Path: tc.path},
				TLS: &tls.ConnectionState{ServerName: serverName},
			}
			sctx := &serveHTTPContext{DestPort: cmpx.Or(tc.port, 443)}
			req := base.WithContext(context.WithValue(base.Context(), serveHTTPContextKey{}, sctx))

			handler, mount, ok := backend.getServeHandler(req)

			// ok, a non-empty mount point and a valid handler must all agree.
			if hasMount := mount != ""; hasMount != ok {
				t.Fatalf("got ok=%v, but got mountPoint=%q", ok, mount)
			}
			if valid := handler.Valid(); valid != ok {
				t.Fatalf("got ok=%v, but valid=%v", ok, handler.Valid())
			}
			if mount != tc.want {
				t.Errorf("got handler at mount %q, want %q", mount, tc.want)
			}
		})
	}
}
// getEtag returns the lowercase hex encoding of the SHA-256 digest of b's
// JSON encoding. It fails the test immediately if b cannot be marshaled.
func getEtag(t *testing.T, b any) string {
	t.Helper()
	encoded, err := json.Marshal(b)
	if err != nil {
		t.Fatal(err)
	}
	digest := sha256.New()
	digest.Write(encoded) // hash.Hash's Write never returns an error
	return hex.EncodeToString(digest.Sum(nil))
}
// TestServeConfigForeground tests the inter-dependency
// between a ServeConfig and a WatchIPNBus:
//  1. Creating a WatchIPNBus returns a sessionID, that
//  2. ServeConfig sets it as the key of the Foreground field.
//  3. ServeConfig expects the WatchIPNBus to clean up the Foreground
//     config when the session is done.
//  4. WatchIPNBus expects the ServeConfig to send a signal (close the channel)
//     if an incoming SetServeConfig removes previous foregrounds.
func TestServeConfigForeground(t *testing.T) {
	b := newTestBackend(t)

	// Run both watchers under a context that is canceled when the test
	// returns, rather than context.Background(), which can never be canceled.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// First watcher: reports its session ID on ch1 and closes ch1 once
	// WatchNotifications returns.
	ch1 := make(chan string, 1)
	go func() {
		defer close(ch1)
		b.WatchNotifications(ctx, ipn.NotifyInitialState, nil, func(roNotify *ipn.Notify) (keepGoing bool) {
			if roNotify.SessionID != "" {
				ch1 <- roNotify.SessionID
			}
			return true
		})
	}()

	// Second watcher: reports its session ID on ch2, then "again" for every
	// later notification so the test can observe that it is still running.
	ch2 := make(chan string, 1)
	go func() {
		b.WatchNotifications(ctx, ipn.NotifyInitialState, nil, func(roNotify *ipn.Notify) (keepGoing bool) {
			msg := "again" // let channel know fn was called again
			if roNotify.SessionID != "" {
				msg = roNotify.SessionID
			}
			select {
			case ch2 <- msg:
				return true
			case <-ctx.Done():
				// The test is over and nobody reads ch2 anymore; stop
				// instead of blocking on the send forever.
				return false
			}
		})
	}()

	var session1 string
	select {
	case session1 = <-ch1:
	case <-time.After(time.Second):
		t.Fatal("timed out waiting on watch notifications session id")
	}

	var session2 string
	select {
	case session2 = <-ch2:
	case <-time.After(time.Second):
		t.Fatal("timed out waiting on watch notifications session id")
	}

	// Register one foreground config per session.
	err := b.SetServeConfig(&ipn.ServeConfig{
		Foreground: map[string]*ipn.ServeConfig{
			session1: {TCP: map[uint16]*ipn.TCPPortHandler{
				443: {TCPForward: "http://localhost:3000"}},
			},
			session2: {TCP: map[uint16]*ipn.TCPPortHandler{
				999: {TCPForward: "http://localhost:4000"}},
			},
		},
	}, "")
	if err != nil {
		t.Fatal(err)
	}

	// Setting a new serve config should shut down WatchNotifications
	// whose session IDs are no longer found: session1 goes, session2 stays.
	err = b.SetServeConfig(&ipn.ServeConfig{
		TCP: map[uint16]*ipn.TCPPortHandler{
			5000: {TCPForward: "http://localhost:5000"},
		},
		Foreground: map[string]*ipn.ServeConfig{
			session2: {TCP: map[uint16]*ipn.TCPPortHandler{
				999: {TCPForward: "http://localhost:4000"}},
			},
		},
	}, "")
	if err != nil {
		t.Fatal(err)
	}

	// ch1 is closed by the first watcher's goroutine when its
	// WatchNotifications call returns.
	select {
	case _, ok := <-ch1:
		if ok {
			t.Fatal("expected channel to be closed")
		}
	case <-time.After(time.Second):
		t.Fatal("timed out waiting on watch notifications closing")
	}

	// check that the second session is still running
	b.send(ipn.Notify{})
	select {
	case _, ok := <-ch2:
		if !ok {
			t.Fatal("expected second session to remain open")
		}
	case <-time.After(time.Second):
		t.Fatal("timed out waiting on second session")
	}
}
// TestServeConfigETag exercises the etag argument of SetServeConfig: an etag
// computed from the currently stored config is accepted, while an invalid or
// outdated etag is rejected with ErrETagMismatch.
func TestServeConfigETag(t *testing.T) {
	b := newTestBackend(t)

	// a nil config with initial etag should succeed
	err := b.SetServeConfig(nil, getEtag(t, nil))
	if err != nil {
		t.Fatal(err)
	}

	// a nil config with an invalid etag should fail
	err = b.SetServeConfig(nil, "abc")
	if !errors.Is(err, ErrETagMismatch) {
		t.Fatalf("expected an etag mismatch error but got %v", err)
	}

	// a new config with the initial etag should succeed
	conf := &ipn.ServeConfig{
		Web: map[ipn.HostPort]*ipn.WebServerConfig{
			"example.ts.net:443": {Handlers: map[string]*ipn.HTTPHandler{
				"/": {Proxy: "http://127.0.0.1:3000"},
			}},
		},
	}
	err = b.SetServeConfig(conf, getEtag(t, nil))
	if err != nil {
		t.Fatal(err)
	}

	// Compute the etag of the config that is now stored.
	confView := b.ServeConfig()
	etag := getEtag(t, confView)
	if etag == "" {
		t.Fatal("expected to get an etag but got an empty string")
	}

	// Modify a copy of the stored config so that its etag changes once saved.
	conf = confView.AsStruct()
	mak.Set(&conf.AllowFunnel, "example.ts.net:443", true)

	// replacing an existing config with an invalid etag should fail
	err = b.SetServeConfig(conf, "invalid etag")
	if !errors.Is(err, ErrETagMismatch) {
		t.Fatalf("expected an etag mismatch error but got %v", err)
	}

	// replacing an existing config with a valid etag should succeed
	err = b.SetServeConfig(conf, etag)
	if err != nil {
		t.Fatal(err)
	}

	// replacing an existing config with a previous etag should fail
	err = b.SetServeConfig(nil, etag)
	if !errors.Is(err, ErrETagMismatch) {
		t.Fatalf("expected an etag mismatch error but got %v", err)
	}

	// replacing an existing config with the new etag should succeed
	newCfg := b.ServeConfig()
	etag = getEtag(t, newCfg)
	err = b.SetServeConfig(nil, etag)
	if err != nil {
		t.Fatal(err)
	}
}
// TestServeHTTPProxy proxies requests through serveWebHandler to a local
// test server that echoes the request headers it received, then checks the
// forwarding and Tailscale identity headers for several source addresses.
func TestServeHTTPProxy(t *testing.T) {
	b := newTestBackend(t)

	// Start test serve endpoint.
	// It pipes all the request headers through the response writer
	// so we can check their values in tests below.
	echoHeaders := func(w http.ResponseWriter, r *http.Request) {
		for name, values := range r.Header {
			w.Header().Add(name, strings.Join(values, ","))
		}
	}
	testServ := httptest.NewServer(http.HandlerFunc(echoHeaders))
	defer testServ.Close()

	const host = "example.ts.net"
	rootProxy := map[string]*ipn.HTTPHandler{
		"/": {Proxy: testServ.URL},
	}
	conf := &ipn.ServeConfig{
		Web: map[ipn.HostPort]*ipn.WebServerConfig{
			host + ":443": {Handlers: rootProxy},
		},
	}
	if err := b.SetServeConfig(conf, ""); err != nil {
		t.Fatal(err)
	}

	type headerCheck struct {
		header string
		want   string
	}
	// noIdentity lists the expectations for a request from srcIP that must
	// carry the forwarding headers but none of the Tailscale user headers.
	noIdentity := func(srcIP string) []headerCheck {
		return []headerCheck{
			{"X-Forwarded-Proto", "https"},
			{"X-Forwarded-For", srcIP},
			{"Tailscale-User-Login", ""},
			{"Tailscale-User-Name", ""},
			{"Tailscale-User-Profile-Pic", ""},
			{"Tailscale-Headers-Info", ""},
		}
	}

	cases := []struct {
		name        string
		srcIP       string
		wantHeaders []headerCheck
	}{
		{
			name:  "request-from-user-within-tailnet",
			srcIP: "100.150.151.152",
			wantHeaders: []headerCheck{
				{"X-Forwarded-Proto", "https"},
				{"X-Forwarded-For", "100.150.151.152"},
				{"Tailscale-User-Login", "someone@example.com"},
				{"Tailscale-User-Name", "Some One"},
				{"Tailscale-User-Profile-Pic", "https://example.com/photo.jpg"},
				{"Tailscale-Headers-Info", "https://tailscale.com/s/serve-headers"},
			},
		},
		{
			name:        "request-from-tagged-node-within-tailnet",
			srcIP:       "100.150.151.153",
			wantHeaders: noIdentity("100.150.151.153"),
		},
		{
			name:        "request-from-outside-tailnet",
			srcIP:       "100.160.161.162",
			wantHeaders: noIdentity("100.160.161.162"),
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			base := &http.Request{
				URL: &url.URL{Path: "/"},
				TLS: &tls.ConnectionState{ServerName: host},
			}
			sctx := &serveHTTPContext{
				DestPort: 443,
				SrcAddr:  netip.MustParseAddrPort(tc.srcIP + ":1234"), // random src port for tests
			}
			req := base.WithContext(context.WithValue(base.Context(), serveHTTPContextKey{}, sctx))

			rec := httptest.NewRecorder()
			b.serveWebHandler(rec, req)

			// Verify the headers.
			respHeader := rec.Result().Header
			for _, check := range tc.wantHeaders {
				got := respHeader.Get(check.header)
				if got == check.want {
					continue
				}
				t.Errorf("invalid %q header; want=%q, got=%q", check.header, check.want, got)
			}
		})
	}
}
// newTestBackend returns a LocalBackend for tests. It is built on a userspace
// engine and in-memory stores, has its var root set to a per-test temporary
// directory, and is shut down via t.Cleanup. Its state is pre-populated with:
//   - a current login profile with ID "id0",
//   - a netmap whose self node is named "example.ts.net" and which holds a
//     single user profile (user ID 1),
//   - two peers owned by user 1, node 152 ("some-peer") and the tagged node
//     153 ("some-tagged-peer"), indexed by the addresses 100.150.151.152
//     and 100.150.151.153 respectively.
//
// Any setup error fails the calling test immediately.
func newTestBackend(t *testing.T) *LocalBackend {
	// Attribute setup failures to the calling test rather than this helper.
	t.Helper()

	sys := &tsd.System{}
	e, err := wgengine.NewUserspaceEngine(t.Logf, wgengine.Config{SetSubsystem: sys.Set})
	if err != nil {
		t.Fatal(err)
	}
	sys.Set(e)
	sys.Set(new(mem.Store))

	b, err := NewLocalBackend(t.Logf, logid.PublicID{}, sys, 0)
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(b.Shutdown)
	dir := t.TempDir()
	b.SetVarRoot(dir)

	// Replace the backend's profile manager with one that already has a
	// current profile.
	pm := must.Get(newProfileManager(new(mem.Store), t.Logf))
	pm.currentProfile = &ipn.LoginProfile{ID: "id0"}
	b.pm = pm

	b.netMap = &netmap.NetworkMap{
		SelfNode: (&tailcfg.Node{
			Name: "example.ts.net",
		}).View(),
		UserProfiles: map[tailcfg.UserID]tailcfg.UserProfile{
			tailcfg.UserID(1): {
				LoginName:     "someone@example.com",
				DisplayName:   "Some One",
				ProfilePicURL: "https://example.com/photo.jpg",
			},
		},
	}
	b.peers = map[tailcfg.NodeID]tailcfg.NodeView{
		152: (&tailcfg.Node{
			ID:           152,
			ComputedName: "some-peer",
			User:         tailcfg.UserID(1),
		}).View(),
		153: (&tailcfg.Node{
			ID:           153,
			ComputedName: "some-tagged-peer",
			Tags:         []string{"tag:server", "tag:test"},
			User:         tailcfg.UserID(1),
		}).View(),
	}
	b.nodeByAddr = map[netip.Addr]tailcfg.NodeID{
		netip.MustParseAddr("100.150.151.152"): 152,
		netip.MustParseAddr("100.150.151.153"): 153,
	}
	return b
}
// TestServeFileOrDirectory serves a small temporary directory tree through
// serveFileOrDirectory, mounted at "/" and at "/doc/", and checks directory
// listings, file contents, redirects and 404 responses.
func TestServeFileOrDirectory(t *testing.T) {
	td := t.TempDir()
	// writeFile creates the file td/suffix with the given contents.
	writeFile := func(suffix, contents string) {
		if err := os.WriteFile(filepath.Join(td, suffix), []byte(contents), 0600); err != nil {
			t.Fatal(err)
		}
	}
	writeFile("foo", "this is foo")
	writeFile("bar", "this is bar")
	if err := os.MkdirAll(filepath.Join(td, "subdir"), 0700); err != nil {
		t.Fatal(err)
	}
	writeFile("subdir/file-a", "this is A")
	writeFile("subdir/file-b", "this is B")
	writeFile("subdir/file-c", "this is C")

	// contains returns a check that every one of subs occurs in the body.
	contains := func(subs ...string) func([]byte, *http.Response) error {
		return func(resBody []byte, res *http.Response) error {
			for _, sub := range subs {
				if !bytes.Contains(resBody, []byte(sub)) {
					return fmt.Errorf("response body does not contain %q: %s", sub, resBody)
				}
			}
			return nil
		}
	}
	// isStatus returns a check that the response status code is wantCode.
	isStatus := func(wantCode int) func([]byte, *http.Response) error {
		return func(resBody []byte, res *http.Response) error {
			if res.StatusCode != wantCode {
				return fmt.Errorf("response status = %d; want %d", res.StatusCode, wantCode)
			}
			return nil
		}
	}
	// isRedirect returns a check that the response is a redirect whose
	// Location header equals wantLocation.
	isRedirect := func(wantLocation string) func([]byte, *http.Response) error {
		return func(resBody []byte, res *http.Response) error {
			switch res.StatusCode {
			case 301, 302, 303, 307, 308:
				if got := res.Header.Get("Location"); got != wantLocation {
					return fmt.Errorf("got Location = %q; want %q", got, wantLocation)
				}
			default:
				return fmt.Errorf("response status = %d; want redirect. body: %s", res.StatusCode, resBody)
			}
			return nil
		}
	}

	b := &LocalBackend{}

	tests := []struct {
		req   string
		mount string
		want  func(resBody []byte, res *http.Response) error
	}{
		// Mounted at /
		{"/", "/", contains("foo", "bar", "subdir")},
		{"/../../.../../../../../../../etc/passwd", "/", isStatus(404)},
		{"/foo", "/", contains("this is foo")},
		{"/bar", "/", contains("this is bar")},
		{"/bar/inside-file", "/", isStatus(404)},
		{"/subdir", "/", isRedirect("/subdir/")},
		{"/subdir/", "/", contains("file-a", "file-b", "file-c")},
		{"/subdir/file-a", "/", contains("this is A")},
		{"/subdir/file-z", "/", isStatus(404)},

		// Mounted at /doc/
		{"/doc", "/doc/", isRedirect("/doc/")},
		{"/doc/", "/doc/", contains("foo", "bar", "subdir")},
		{"/doc/../../.../../../../../../../etc/passwd", "/doc/", isStatus(404)},
		{"/doc/foo", "/doc/", contains("this is foo")},
		{"/doc/bar", "/doc/", contains("this is bar")},
		{"/doc/bar/inside-file", "/doc/", isStatus(404)},
		{"/doc/subdir", "/doc/", isRedirect("/doc/subdir/")},
		{"/doc/subdir/", "/doc/", contains("file-a", "file-b", "file-c")},
		{"/doc/subdir/file-a", "/doc/", contains("this is A")},
		{"/doc/subdir/file-z", "/doc/", isStatus(404)},
	}
	for _, tt := range tests {
		if tt.want == nil {
			// A table entry without a check is a mistake in the test itself;
			// report it, but keep running the remaining cases.
			t.Errorf("no want for path %q", tt.req)
			continue
		}
		rec := httptest.NewRecorder()
		req := httptest.NewRequest("GET", tt.req, nil)
		b.serveFileOrDirectory(rec, req, td, tt.mount)
		if err := tt.want(rec.Body.Bytes(), rec.Result()); err != nil {
			t.Errorf("error for req %q (mount %v): %v", tt.req, tt.mount, err)
		}
	}
}