You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
tailscale/ipn/ipnlocal/c2n_test.go

318 lines
8.2 KiB
Go

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package ipnlocal
import (
"bytes"
"cmp"
"crypto/x509"
"encoding/json"
"fmt"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"reflect"
"testing"
"time"
"tailscale.com/ipn/store/mem"
"tailscale.com/tailcfg"
"tailscale.com/tstest"
"tailscale.com/types/key"
"tailscale.com/types/logger"
"tailscale.com/types/netmap"
"tailscale.com/types/opt"
"tailscale.com/types/views"
"tailscale.com/util/must"
gcmp "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
)
// TestHandleC2NTLSCertStatus exercises handleC2NTLSCertStatus against an
// on-disk cert directory. It covers four cases:
//   - a request with no domain parameter (expects 400)
//   - a domain with no certificate on disk
//   - a certificate that is valid at the fake clock's time
//   - the same certificate evaluated after it has expired
func TestHandleC2NTLSCertStatus(t *testing.T) {
	b := &LocalBackend{
		store:   &mem.Store{},
		varRoot: t.TempDir(),
	}
	certDir, err := b.certDir()
	if err != nil {
		t.Fatalf("certDir error: %v", err)
	}
	if _, err := b.getCertStore(); err != nil {
		t.Fatalf("getCertStore error: %v", err)
	}

	// Install the test CA as the trusted root pool for the duration of this
	// test, and restore the package-level override afterwards.
	testRoot, err := certTestFS.ReadFile("testdata/rootCA.pem")
	if err != nil {
		t.Fatal(err)
	}
	roots := x509.NewCertPool()
	if !roots.AppendCertsFromPEM(testRoot) {
		t.Fatal("Unable to add test CA to the cert pool")
	}
	testX509Roots = roots
	t.Cleanup(func() { testX509Roots = nil })

	tests := []struct {
		name       string
		domain     string
		copyFile   bool   // copy testdata/example.com.pem to the certDir
		wantStatus int    // 0 means 200
		wantError  string // wanted non-JSON non-200 error
		now        time.Time
		want       *tailcfg.C2NTLSCertInfo
	}{
		{
			name:       "no domain",
			wantStatus: 400,
			wantError:  "no 'domain'\n",
		},
		{
			name:   "missing",
			domain: "example.com",
			want: &tailcfg.C2NTLSCertInfo{
				Error:   "no certificate",
				Missing: true,
			},
		},
		{
			name:     "valid",
			domain:   "example.com",
			now:      time.Date(2023, time.February, 20, 0, 0, 0, 0, time.UTC),
			copyFile: true,
			want: &tailcfg.C2NTLSCertInfo{
				Valid:     true,
				NotBefore: "2023-02-07T20:34:18Z",
				NotAfter:  "2025-05-07T19:34:18Z",
			},
		},
		{
			name:     "expired",
			domain:   "example.com",
			now:      time.Date(2030, time.February, 20, 0, 0, 0, 0, time.UTC),
			copyFile: true,
			want: &tailcfg.C2NTLSCertInfo{
				Error:   "cert expired",
				Expired: true,
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Start every case from an empty cert directory. Setup failures
			// are fatal so they are not mistaken for handler bugs.
			if err := os.RemoveAll(certDir); err != nil {
				t.Fatal(err)
			}
			if tt.copyFile {
				if err := os.MkdirAll(certDir, 0755); err != nil {
					t.Fatal(err)
				}
				if err := os.WriteFile(filepath.Join(certDir, "example.com.crt"),
					must.Get(os.ReadFile("testdata/example.com.pem")), 0644); err != nil {
					t.Fatal(err)
				}
				if err := os.WriteFile(filepath.Join(certDir, "example.com.key"),
					must.Get(os.ReadFile("testdata/example.com-key.pem")), 0644); err != nil {
					t.Fatal(err)
				}
			}
			// The fake clock determines whether the cert is valid or expired.
			b.clock = tstest.NewClock(tstest.ClockOpts{
				Start: tt.now,
			})

			rec := httptest.NewRecorder()
			handleC2NTLSCertStatus(b, rec, httptest.NewRequest("GET", "/tls-cert-status?domain="+url.QueryEscape(tt.domain), nil))
			res := rec.Result()
			wantStatus := cmp.Or(tt.wantStatus, 200)
			if res.StatusCode != wantStatus {
				t.Fatalf("status code = %v; want %v. Body: %s", res.Status, wantStatus, rec.Body.Bytes())
			}
			if wantStatus == 200 {
				// Successful responses carry a JSON C2NTLSCertInfo body.
				var got tailcfg.C2NTLSCertInfo
				if err := json.Unmarshal(rec.Body.Bytes(), &got); err != nil {
					t.Fatalf("bad JSON: %v", err)
				}
				if !reflect.DeepEqual(&got, tt.want) {
					t.Errorf("got %v; want %v", logger.AsJSON(got), logger.AsJSON(tt.want))
				}
			} else if tt.wantError != "" {
				// Error responses are compared as plain text.
				if got := rec.Body.String(); got != tt.wantError {
					t.Errorf("body = %q; want %q", got, tt.wantError)
				}
			}
		})
	}
}
// reflectNonzero returns a non-zero reflect.Value of type t. It is used to
// fill struct fields with values that differ from their zero value.
//
// Booleans, strings, signed/unsigned integers, floats, slices, pointers and
// maps are supported for any named type of those kinds. Among struct types,
// only key.NodePrivate is supported. It panics for any other type so that
// the caller is forced to teach it about new field types.
func reflectNonzero(t reflect.Type) reflect.Value {
	switch t.Kind() {
	case reflect.Bool:
		return reflect.ValueOf(true)
	case reflect.String:
		// opt.Bool is string-kinded; give it "true" rather than an
		// arbitrary string.
		if t == reflect.TypeFor[opt.Bool]() {
			return reflect.ValueOf("true").Convert(t)
		}
		return reflect.ValueOf("foo").Convert(t)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		// Convert so named integer types (e.g. time.Duration) get type t.
		return reflect.ValueOf(int64(1)).Convert(t)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return reflect.ValueOf(uint64(1)).Convert(t)
	case reflect.Float32, reflect.Float64:
		return reflect.ValueOf(float64(1)).Convert(t)
	case reflect.Slice:
		// A one-element slice is both non-nil and non-empty.
		return reflect.MakeSlice(t, 1, 1)
	case reflect.Pointer:
		// A pointer to a freshly allocated value of the element type.
		return reflect.New(t.Elem())
	case reflect.Map:
		// An empty but non-nil map.
		return reflect.MakeMap(t)
	case reflect.Struct:
		if t == reflect.TypeFor[key.NodePrivate]() {
			return reflect.ValueOf(key.NewNode())
		}
	}
	panic(fmt.Sprintf("unhandled %v", t))
}
// setFieldsToRedact populates nm's settable (exported) fields according to
// fieldMap, which is keyed by NetworkMap field name.
//
// Fields mapped to true are overwritten with a value from reflectNonzero.
// Fields mapped to false are left as they are. A settable field that has no
// entry in fieldMap is reported via t.Errorf and left untouched. This forces
// every exported field to be listed.
func setFieldsToRedact(t *testing.T, nm *netmap.NetworkMap, fieldMap map[string]bool) {
	t.Helper()
	elem := reflect.ValueOf(nm).Elem()
	typ := elem.Type()
	for idx := 0; idx < elem.NumField(); idx++ {
		field := elem.Field(idx)
		if !field.CanSet() {
			// Unexported fields cannot be set through reflection.
			continue
		}
		fieldName := typ.Field(idx).Name
		redact, known := fieldMap[fieldName]
		if !known {
			t.Errorf("fieldMap missing field %q", fieldName)
			continue
		}
		if redact {
			field.Set(reflectNonzero(field.Type()))
		}
	}
}
// TestRedactNetmapPrivateKeys checks that redactNetmapPrivateKeys clears every
// NetworkMap field classified as private.
//
// Every exported NetworkMap field must appear in the table below;
// setFieldsToRedact reports any that are missing. Fields marked as redacted
// are populated before redaction, and the result must equal an empty
// NetworkMap.
func TestRedactNetmapPrivateKeys(t *testing.T) {
	const (
		redact = true
		keep   = false
	)
	fieldMap := map[string]bool{
		// Private fields (should be redacted):
		"PrivateKey": redact,

		// Public fields (should not be redacted):
		"AllCaps":           keep,
		"CollectServices":   keep,
		"DERPMap":           keep,
		"DNS":               keep,
		"DisplayMessages":   keep,
		"Domain":            keep,
		"DomainAuditLogID":  keep,
		"Expiry":            keep,
		"MachineKey":        keep,
		"Name":              keep,
		"NodeKey":           keep,
		"PacketFilter":      keep,
		"PacketFilterRules": keep,
		"Peers":             keep,
		"SSHPolicy":         keep,
		"SelfNode":          keep,
		"TKAEnabled":        keep,
		"TKAHead":           keep,
		"UserProfiles":      keep,
	}

	input := new(netmap.NetworkMap)
	setFieldsToRedact(t, input, fieldMap)

	// Only the redacted netmap is examined here; the second result is not
	// checked by this test.
	redacted, _ := redactNetmapPrivateKeys(input)
	if want := new(netmap.NetworkMap); !reflect.DeepEqual(redacted, want) {
		t.Errorf("redacted netmap is not empty: %+v", redacted)
	}
}
// TestHandleC2NDebugNetmap checks the c2n /debug/netmap handler. It covers
// a plain GET and POSTed C2NDebugNetmapRequest bodies, with and without
// OmitFields.
//
// Expectations checked:
//   - The "Current" netmap in the response has its PrivateKey cleared.
//   - Fields named in OmitFields come back zeroed.
//   - Unknown field names in OmitFields are ignored.
func TestHandleC2NDebugNetmap(t *testing.T) {
	nm := &netmap.NetworkMap{
		Name: "myhost",
		SelfNode: (&tailcfg.Node{
			ID:       100,
			Name:     "myhost",
			StableID: "deadbeef",
			Key:      key.NewNode().Public(),
			Hostinfo: (&tailcfg.Hostinfo{Hostname: "myhost"}).View(),
		}).View(),
		Peers: []tailcfg.NodeView{
			(&tailcfg.Node{
				ID:       101,
				Name:     "peer1",
				StableID: "deadbeef",
				Key:      key.NewNode().Public(),
				Hostinfo: (&tailcfg.Hostinfo{Hostname: "peer1"}).View(),
			}).View(),
		},
		PrivateKey: key.NewNode(),
	}
	// withoutPrivateKey is nm as the handler is expected to return it when no
	// fields are omitted: identical except for the zeroed private key.
	withoutPrivateKey := *nm
	withoutPrivateKey.PrivateKey = key.NodePrivate{}

	for _, tt := range []struct {
		name string
		req  *tailcfg.C2NDebugNetmapRequest // nil means send a plain GET
		want *netmap.NetworkMap
	}{
		{
			name: "simple_get",
			want: &withoutPrivateKey,
		},
		{
			name: "post_no_omit",
			req:  &tailcfg.C2NDebugNetmapRequest{},
			want: &withoutPrivateKey,
		},
		{
			name: "post_omit_peers_and_name",
			req:  &tailcfg.C2NDebugNetmapRequest{OmitFields: []string{"Peers", "Name"}},
			want: &netmap.NetworkMap{
				SelfNode: nm.SelfNode,
			},
		},
		{
			name: "post_omit_nonexistent_field",
			req:  &tailcfg.C2NDebugNetmapRequest{OmitFields: []string{"ThisFieldDoesNotExist"}},
			want: &withoutPrivateKey,
		},
	} {
		t.Run(tt.name, func(t *testing.T) {
			b := newTestLocalBackend(t)
			b.currentNode().SetNetMap(nm)

			rec := httptest.NewRecorder()
			req := httptest.NewRequest("GET", "/debug/netmap", nil)
			if tt.req != nil {
				// Named body (not b) so the LocalBackend b is not shadowed.
				body, err := json.Marshal(tt.req)
				if err != nil {
					t.Fatalf("json.Marshal: %v", err)
				}
				req = httptest.NewRequest("POST", "/debug/netmap", bytes.NewReader(body))
			}
			handleC2NDebugNetMap(b, rec, req)

			res := rec.Result()
			wantStatus := 200
			if res.StatusCode != wantStatus {
				t.Fatalf("status code = %v; want %v. Body: %s", res.Status, wantStatus, rec.Body.Bytes())
			}

			// Decode the response envelope, then the embedded netmap.
			var resp tailcfg.C2NDebugNetmapResponse
			if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
				t.Fatalf("bad response JSON: %v", err)
			}
			got := &netmap.NetworkMap{}
			if err := json.Unmarshal(resp.Current, got); err != nil {
				t.Fatalf("bad netmap JSON: %v", err)
			}
			if diff := gcmp.Diff(tt.want, got,
				gcmp.AllowUnexported(netmap.NetworkMap{}, key.NodePublic{}, views.Slice[tailcfg.FilterRule]{}),
				cmpopts.EquateComparable(key.MachinePublic{}),
			); diff != "" {
				t.Errorf("netmap mismatch (-want +got):\n%s", diff)
			}
		})
	}
}