From 2690b4762f9f6eded9857acfd70ab4a913aebcc1 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Thu, 5 Dec 2024 15:25:42 -0800 Subject: [PATCH] Revert "VERSION.txt: this is v1.78.0" This reverts commit 0267fe83b200f1702a2fa0a395442c02a053fadb. Reason: it converted the tree to Windows line endings. Updates #14299 Change-Id: I2271a61d43e99bd0bbcf9f4831e8783e570ba08a Signed-off-by: Brad Fitzpatrick --- .bencher/config.yaml | 2 +- .gitattributes | 4 +- .github/ISSUE_TEMPLATE/bug_report.yml | 162 +- .github/ISSUE_TEMPLATE/config.yml | 14 +- .github/ISSUE_TEMPLATE/feature_request.yml | 84 +- .github/dependabot.yml | 42 +- AUTHORS | 34 +- CODEOWNERS | 2 +- CODE_OF_CONDUCT.md | 270 ++-- LICENSE | 56 +- PATENTS | 48 +- SECURITY.md | 16 +- VERSION.txt | 2 +- atomicfile/atomicfile.go | 102 +- atomicfile/atomicfile_test.go | 94 +- chirp/chirp.go | 326 ++-- chirp/chirp_test.go | 384 ++--- client/tailscale/apitype/controltype.go | 38 +- client/tailscale/dns.go | 466 +++--- client/tailscale/example/servetls/servetls.go | 56 +- client/tailscale/keys.go | 332 ++-- client/tailscale/routes.go | 190 +-- client/tailscale/tailnet.go | 84 +- client/web/qnap.go | 254 +-- client/web/src/assets/icons/arrow-right.svg | 8 +- .../web/src/assets/icons/arrow-up-circle.svg | 10 +- client/web/src/assets/icons/check-circle.svg | 8 +- client/web/src/assets/icons/check.svg | 6 +- client/web/src/assets/icons/chevron-down.svg | 6 +- client/web/src/assets/icons/eye.svg | 22 +- client/web/src/assets/icons/search.svg | 8 +- .../web/src/assets/icons/tailscale-icon.svg | 36 +- .../web/src/assets/icons/tailscale-logo.svg | 40 +- client/web/src/assets/icons/user.svg | 8 +- client/web/src/assets/icons/x-circle.svg | 10 +- client/web/synology.go | 118 +- clientupdate/distsign/distsign.go | 972 +++++------ clientupdate/distsign/roots.go | 108 +- clientupdate/distsign/roots/crawshaw-root.pem | 6 +- .../roots/distsign-prod-root-1-pub.pem | 6 +- clientupdate/distsign/roots_test.go | 32 +- cmd/addlicense/main.go | 146 +- cmd/cloner/cloner_test.go | 120 +- cmd/containerboot/test_tailscale.sh | 16 +- cmd/containerboot/test_tailscaled.sh | 76 +- cmd/get-authkey/.gitignore | 2 +- cmd/gitops-pusher/.gitignore | 2 +- cmd/gitops-pusher/README.md | 96 +- cmd/gitops-pusher/cache.go | 132 +- cmd/gitops-pusher/gitops-pusher_test.go | 110 +- cmd/k8s-operator/deploy/chart/.helmignore | 46 +- cmd/k8s-operator/deploy/chart/Chart.yaml | 58 +- .../chart/templates/apiserverproxy-rbac.yaml | 52 +- .../deploy/chart/templates/oauth-secret.yaml | 26 +- .../deploy/manifests/authproxy-rbac.yaml | 46 +- cmd/mkmanifest/main.go | 102 +- cmd/mkpkg/main.go | 268 ++-- cmd/mkversion/mkversion.go | 88 +- cmd/nardump/README.md | 14 +- cmd/nardump/nardump.go | 368 ++--- cmd/nginx-auth/.gitignore | 8 +- cmd/nginx-auth/README.md | 322 ++-- cmd/nginx-auth/deb/postinst.sh | 28 +- cmd/nginx-auth/deb/postrm.sh | 38 +- cmd/nginx-auth/deb/prerm.sh | 16 +- cmd/nginx-auth/mkdeb.sh | 64 +- cmd/nginx-auth/nginx-auth.go | 256 +-- cmd/nginx-auth/rpm/postrm.sh | 18 +- cmd/nginx-auth/rpm/prerm.sh | 18 +- cmd/nginx-auth/tailscale.nginx-auth.service | 22 +- cmd/nginx-auth/tailscale.nginx-auth.socket | 16 +- cmd/pgproxy/README.md | 84 +- cmd/printdep/printdep.go | 82 +- cmd/sniproxy/.gitignore | 2 +- cmd/sniproxy/handlers_test.go | 318 ++-- cmd/sniproxy/server.go | 654 ++++---- cmd/sniproxy/server_test.go | 190 +-- cmd/sniproxy/sniproxy.go | 582 +++---- cmd/speedtest/speedtest.go | 242 +-- cmd/ssh-auth-none-demo/ssh-auth-none-demo.go | 374 ++--- cmd/sync-containers/main.go | 428 ++--- cmd/tailscale/cli/diag.go | 148 +- cmd/tailscale/cli/diag_other.go | 30 +- cmd/tailscale/cli/set_test.go | 262 +-- cmd/tailscale/cli/ssh_exec.go | 48 +- cmd/tailscale/cli/ssh_exec_js.go | 32 +- cmd/tailscale/cli/ssh_exec_windows.go | 74 +- cmd/tailscale/cli/ssh_unix.go | 98 +- cmd/tailscale/cli/web_test.go | 90 +- cmd/tailscale/generate.go | 16 +- cmd/tailscale/tailscale.go | 52 +- cmd/tailscale/windows-manifest.xml | 26 +- cmd/tailscaled/childproc/childproc.go | 38 +- cmd/tailscaled/generate.go | 16 +- cmd/tailscaled/install_darwin.go | 398 ++--- cmd/tailscaled/install_windows.go | 248 +-- cmd/tailscaled/proxy.go | 160 +- cmd/tailscaled/sigpipe.go | 24 +- cmd/tailscaled/tailscaled.defaults | 16 +- cmd/tailscaled/tailscaled.openrc | 50 +- cmd/tailscaled/tailscaled_bird.go | 34 +- cmd/tailscaled/tailscaled_notwindows.go | 28 +- cmd/tailscaled/windows-manifest.xml | 26 +- cmd/tailscaled/with_cli.go | 46 +- cmd/testwrapper/args_test.go | 194 +-- cmd/testwrapper/flakytest/flakytest.go | 88 +- cmd/testwrapper/flakytest/flakytest_test.go | 86 +- cmd/tsconnect/.gitignore | 6 +- cmd/tsconnect/README.md | 98 +- cmd/tsconnect/README.pkg.md | 6 +- cmd/tsconnect/build-pkg.go | 198 +-- cmd/tsconnect/dev-pkg.go | 36 +- cmd/tsconnect/dev.go | 36 +- cmd/tsconnect/dist/placeholder | 4 +- cmd/tsconnect/index.html | 40 +- cmd/tsconnect/package.json | 50 +- cmd/tsconnect/package.json.tmpl | 32 +- cmd/tsconnect/serve.go | 288 ++-- cmd/tsconnect/src/app/app.tsx | 294 ++-- cmd/tsconnect/src/app/go-panic-display.tsx | 40 +- cmd/tsconnect/src/app/header.tsx | 74 +- cmd/tsconnect/src/app/index.css | 148 +- cmd/tsconnect/src/app/index.ts | 72 +- cmd/tsconnect/src/app/ssh.tsx | 314 ++-- cmd/tsconnect/src/app/url-display.tsx | 62 +- cmd/tsconnect/src/lib/js-state-store.ts | 26 +- cmd/tsconnect/src/pkg/pkg.css | 16 +- cmd/tsconnect/src/pkg/pkg.ts | 80 +- cmd/tsconnect/src/types/esbuild.d.ts | 28 +- cmd/tsconnect/src/types/wasm_js.d.ts | 206 +-- cmd/tsconnect/tailwind.config.js | 16 +- cmd/tsconnect/tsconfig.json | 30 +- cmd/tsconnect/tsconnect.go | 142 +- cmd/tsconnect/yarn.lock | 1426 ++++++++--------- cmd/tsshd/tsshd.go | 24 +- control/controlbase/conn.go | 816 +++++----- control/controlbase/handshake.go | 988 ++++++------ control/controlbase/interop_test.go | 512 +++--- control/controlbase/messages.go | 174 +- control/controlclient/sign.go | 84 +- control/controlclient/sign_supported_test.go | 472 +++--- control/controlclient/sign_unsupported.go | 32 +- control/controlclient/status.go | 250 +-- control/controlhttp/client_common.go | 34 +- derp/README.md | 120 +- derp/testdata/example_ss.txt | 16 +- disco/disco_fuzzer.go | 34 +- disco/disco_test.go | 236 +-- disco/pcap.go | 80 +- docs/bird/sample_bird.conf | 32 +- docs/bird/tailscale_bird.conf | 8 +- docs/k8s/Makefile | 50 +- docs/k8s/rolebinding.yaml | 26 +- docs/k8s/sa.yaml | 12 +- docs/sysv/tailscale.init | 126 +- doctor/doctor.go | 158 +- doctor/doctor_test.go | 98 +- doctor/permissions/permissions_bsd.go | 46 +- doctor/permissions/permissions_linux.go | 124 +- doctor/permissions/permissions_other.go | 34 +- doctor/permissions/permissions_test.go | 24 +- doctor/routetable/routetable.go | 68 +- envknob/envknob_nottest.go | 32 +- envknob/envknob_testable.go | 46 +- envknob/logknob/logknob.go | 170 +- envknob/logknob/logknob_test.go | 204 +-- gomod_test.go | 50 +- hostinfo/hostinfo_darwin.go | 42 +- hostinfo/hostinfo_freebsd.go | 128 +- hostinfo/hostinfo_test.go | 102 +- hostinfo/hostinfo_uname.go | 76 +- hostinfo/wol.go | 212 +-- ipn/ipnlocal/breaktcp_darwin.go | 60 +- ipn/ipnlocal/breaktcp_linux.go | 60 +- ipn/ipnlocal/expiry_test.go | 602 +++---- ipn/ipnlocal/peerapi_h2c.go | 40 +- ipn/ipnlocal/testdata/example.com-key.pem | 54 +- ipn/ipnlocal/testdata/example.com.pem | 50 +- ipn/ipnlocal/testdata/rootCA.pem | 58 +- ipn/ipnserver/proxyconnect_js.go | 20 +- ipn/ipnserver/server_test.go | 92 +- ipn/localapi/disabled_stubs.go | 30 +- ipn/localapi/pprof.go | 56 +- ipn/policy/policy.go | 94 +- ipn/store/awsstore/store_aws.go | 372 ++--- ipn/store/awsstore/store_aws_stub.go | 36 +- ipn/store/awsstore/store_aws_test.go | 328 ++-- ipn/store/stores_test.go | 358 ++--- ipn/store_test.go | 96 +- jsondb/db.go | 114 +- jsondb/db_test.go | 110 +- licenses/licenses.go | 42 +- log/filelogger/log.go | 456 +++--- log/filelogger/log_test.go | 54 +- logpolicy/logpolicy_test.go | 72 +- logtail/.gitignore | 12 +- logtail/README.md | 18 +- logtail/api.md | 388 ++--- logtail/example/logreprocess/demo.sh | 172 +- logtail/example/logreprocess/logreprocess.go | 230 +-- logtail/example/logtail/logtail.go | 92 +- logtail/filch/filch.go | 568 +++---- logtail/filch/filch_stub.go | 46 +- logtail/filch/filch_unix.go | 60 +- logtail/filch/filch_windows.go | 86 +- metrics/fds_linux.go | 82 +- metrics/fds_notlinux.go | 16 +- metrics/metrics.go | 326 ++-- net/art/art_test.go | 40 +- net/art/table.go | 1282 +++++++-------- net/dns/debian_resolvconf.go | 368 ++--- net/dns/direct_notlinux.go | 20 +- net/dns/flush_default.go | 20 +- net/dns/ini.go | 60 +- net/dns/ini_test.go | 76 +- net/dns/noop.go | 34 +- net/dns/resolvconf-workaround.sh | 124 +- net/dns/resolvconf.go | 60 +- net/dns/resolvconffile/resolvconffile.go | 248 +-- net/dns/resolvconfpath_default.go | 22 +- net/dns/resolvconfpath_gokrazy.go | 22 +- net/dns/resolver/doh_test.go | 198 +-- net/dns/resolver/macios_ext.go | 52 +- net/dns/resolver/tsdns_server_test.go | 666 ++++---- net/dns/utf.go | 110 +- net/dns/utf_test.go | 48 +- net/dnscache/dnscache_test.go | 484 +++--- net/dnscache/messagecache_test.go | 582 +++---- net/dnsfallback/update-dns-fallbacks.go | 90 +- net/memnet/conn.go | 228 +-- net/memnet/conn_test.go | 42 +- net/memnet/listener.go | 200 +-- net/memnet/listener_test.go | 66 +- net/memnet/memnet.go | 16 +- net/memnet/pipe.go | 488 +++--- net/memnet/pipe_test.go | 234 +-- net/netaddr/netaddr.go | 98 +- net/neterror/neterror.go | 164 +- net/neterror/neterror_linux.go | 52 +- net/neterror/neterror_linux_test.go | 108 +- net/neterror/neterror_windows.go | 32 +- net/netkernelconf/netkernelconf.go | 10 +- net/netknob/netknob.go | 58 +- net/netmon/netmon_darwin_test.go | 54 +- net/netmon/netmon_freebsd.go | 112 +- net/netmon/netmon_linux.go | 580 +++---- net/netmon/netmon_polling.go | 42 +- net/netmon/polling.go | 172 +- net/netns/netns_android.go | 150 +- net/netns/netns_default.go | 44 +- net/netns/netns_linux_test.go | 28 +- net/netns/netns_test.go | 156 +- net/netns/socks.go | 38 +- net/netstat/netstat.go | 70 +- net/netstat/netstat_noimpl.go | 28 +- net/netstat/netstat_test.go | 42 +- net/packet/doc.go | 30 +- net/packet/header.go | 132 +- net/packet/icmp.go | 56 +- net/packet/icmp6_test.go | 158 +- net/packet/ip4.go | 232 +-- net/packet/ip6.go | 152 +- net/packet/tsmp_test.go | 146 +- net/packet/udp4.go | 116 +- net/packet/udp6.go | 108 +- net/ping/ping.go | 686 ++++---- net/ping/ping_test.go | 700 ++++---- net/portmapper/pcp_test.go | 124 +- net/proxymux/mux.go | 288 ++-- net/routetable/routetable_darwin.go | 72 +- net/routetable/routetable_freebsd.go | 56 +- net/routetable/routetable_other.go | 34 +- net/sockstats/sockstats.go | 242 +-- net/sockstats/sockstats_noop.go | 76 +- net/sockstats/sockstats_tsgo_darwin.go | 60 +- net/speedtest/speedtest.go | 174 +- net/speedtest/speedtest_client.go | 82 +- net/speedtest/speedtest_server.go | 292 ++-- net/speedtest/speedtest_test.go | 166 +- net/stun/stun.go | 624 ++++---- net/stun/stun_fuzzer.go | 24 +- net/tcpinfo/tcpinfo.go | 102 +- net/tcpinfo/tcpinfo_darwin.go | 66 +- net/tcpinfo/tcpinfo_linux.go | 66 +- net/tcpinfo/tcpinfo_other.go | 30 +- net/tlsdial/deps_test.go | 16 +- net/tsdial/dnsmap_test.go | 250 +-- net/tsdial/dohclient.go | 200 +-- net/tsdial/dohclient_test.go | 62 +- net/tshttpproxy/mksyscall.go | 22 +- net/tshttpproxy/tshttpproxy_linux.go | 48 +- net/tshttpproxy/tshttpproxy_synology_test.go | 752 ++++----- net/tshttpproxy/tshttpproxy_windows.go | 552 +++---- net/tstun/fake.go | 116 +- net/tstun/ifstatus_noop.go | 36 +- net/tstun/ifstatus_windows.go | 218 +-- net/tstun/linkattrs_linux.go | 126 +- net/tstun/linkattrs_notlinux.go | 24 +- net/tstun/mtu.go | 322 ++-- net/tstun/mtu_test.go | 198 +-- net/tstun/tun_linux.go | 206 +-- net/tstun/tun_macos.go | 50 +- net/tstun/tun_notwindows.go | 24 +- packages/deb/deb.go | 364 ++--- packages/deb/deb_test.go | 410 ++--- paths/migrate.go | 116 +- paths/paths.go | 184 +-- paths/paths_windows.go | 200 +-- portlist/clean.go | 72 +- portlist/clean_test.go | 114 +- portlist/netstat_test.go | 184 +-- portlist/poller.go | 244 +-- portlist/portlist.go | 160 +- portlist/portlist_macos.go | 460 +++--- portlist/portlist_windows.go | 206 +-- posture/serialnumber_macos.go | 148 +- posture/serialnumber_notmacos_test.go | 76 +- posture/serialnumber_test.go | 32 +- pull-toolchain.sh | 32 +- release/deb/debian.postrm.sh | 34 +- release/deb/debian.prerm.sh | 14 +- release/dist/memoize.go | 172 +- release/dist/synology/files/Tailscale.sc | 10 +- release/dist/synology/files/config | 22 +- release/dist/synology/files/index.cgi | 4 +- release/dist/synology/files/logrotate-dsm6 | 16 +- release/dist/synology/files/logrotate-dsm7 | 16 +- release/dist/synology/files/privilege-dsm6 | 14 +- release/dist/synology/files/privilege-dsm7 | 14 +- .../files/privilege-dsm7.for-package-center | 26 +- release/dist/synology/files/resource | 20 +- .../dist/synology/files/scripts/postupgrade | 4 +- .../dist/synology/files/scripts/preupgrade | 4 +- .../synology/files/scripts/start-stop-status | 258 +-- release/dist/unixpkgs/pkgs.go | 944 +++++------ release/dist/unixpkgs/targets.go | 254 +-- release/release.go | 30 +- release/rpm/rpm.postinst.sh | 82 +- release/rpm/rpm.postrm.sh | 16 +- release/rpm/rpm.prerm.sh | 16 +- safesocket/safesocket_test.go | 24 +- smallzstd/testdata | 28 +- smallzstd/zstd.go | 156 +- syncs/locked.go | 64 +- syncs/locked_test.go | 240 +-- syncs/shardedmap.go | 276 ++-- syncs/shardedmap_test.go | 162 +- tailcfg/proto_port_range.go | 374 ++--- tailcfg/proto_port_range_test.go | 262 +-- tailcfg/tka.go | 528 +++--- taildrop/delete.go | 410 ++--- taildrop/delete_test.go | 304 ++-- taildrop/resume_test.go | 148 +- taildrop/retrieve.go | 356 ++-- tempfork/gliderlabs/ssh/LICENSE | 54 +- tempfork/gliderlabs/ssh/README.md | 192 +-- tempfork/gliderlabs/ssh/agent.go | 166 +- tempfork/gliderlabs/ssh/conn.go | 110 +- tempfork/gliderlabs/ssh/context.go | 328 ++-- tempfork/gliderlabs/ssh/context_test.go | 98 +- tempfork/gliderlabs/ssh/doc.go | 90 +- tempfork/gliderlabs/ssh/example_test.go | 100 +- tempfork/gliderlabs/ssh/options.go | 168 +- tempfork/gliderlabs/ssh/options_test.go | 222 +-- tempfork/gliderlabs/ssh/server.go | 918 +++++------ tempfork/gliderlabs/ssh/server_test.go | 256 +-- tempfork/gliderlabs/ssh/session.go | 772 ++++----- tempfork/gliderlabs/ssh/session_test.go | 880 +++++----- tempfork/gliderlabs/ssh/ssh.go | 312 ++-- tempfork/gliderlabs/ssh/ssh_test.go | 34 +- tempfork/gliderlabs/ssh/tcpip.go | 386 ++--- tempfork/gliderlabs/ssh/tcpip_test.go | 170 +- tempfork/gliderlabs/ssh/util.go | 314 ++-- tempfork/gliderlabs/ssh/wrap.go | 66 +- tempfork/heap/heap.go | 242 +-- tka/aum_test.go | 506 +++--- tka/builder.go | 360 ++--- tka/builder_test.go | 540 +++---- tka/deeplink.go | 442 ++--- tka/deeplink_test.go | 104 +- tka/key.go | 318 ++-- tka/key_test.go | 194 +-- tka/state.go | 630 ++++---- tka/state_test.go | 520 +++--- tka/sync_test.go | 754 ++++----- tka/tailchonk_test.go | 1386 ++++++++-------- tka/tka_test.go | 1308 +++++++-------- tool/binaryen.rev | 2 +- tool/go | 14 +- tool/gocross/env.go | 262 +-- tool/gocross/env_test.go | 198 +-- tool/gocross/exec_other.go | 40 +- tool/gocross/exec_unix.go | 24 +- tool/helm | 138 +- tool/helm.rev | 2 +- tool/node | 130 +- tool/wasm-opt | 148 +- tool/yarn | 86 +- tool/yarn.rev | 2 +- tsnet/example/tshello/tshello.go | 120 +- .../tsnet-http-client/tsnet-http-client.go | 88 +- tsnet/example/web-client/web-client.go | 92 +- tsnet/example_tshello_test.go | 144 +- tstest/allocs.go | 100 +- tstest/archtest/qemu_test.go | 146 +- tstest/clock.go | 1388 ++++++++-------- tstest/deptest/deptest_test.go | 20 +- tstest/integration/gen_deps.go | 130 +- tstest/integration/vms/README.md | 190 +-- tstest/integration/vms/distros.hujson | 78 +- tstest/integration/vms/distros_test.go | 28 +- tstest/integration/vms/dns_tester.go | 108 +- tstest/integration/vms/doc.go | 12 +- tstest/integration/vms/harness_test.go | 484 +++--- tstest/integration/vms/nixos_test.go | 462 +++--- tstest/integration/vms/regex_flag.go | 58 +- tstest/integration/vms/regex_flag_test.go | 42 +- tstest/integration/vms/runner.nix | 178 +- tstest/integration/vms/squid.conf | 76 +- tstest/integration/vms/top_level_test.go | 248 +-- tstest/integration/vms/udp_tester.go | 154 +- tstest/log_test.go | 94 +- tstest/natlab/firewall.go | 312 ++-- tstest/natlab/nat.go | 504 +++--- tstest/tstest.go | 190 +-- tstest/tstest_test.go | 48 +- tstime/mono/mono.go | 254 +-- tstime/rate/rate.go | 180 +-- tstime/tstime.go | 370 ++--- tstime/tstime_test.go | 72 +- tsweb/debug_test.go | 416 ++--- tsweb/promvarz/promvarz_test.go | 76 +- types/appctype/appconnector_test.go | 156 +- types/dnstype/dnstype.go | 136 +- types/empty/message.go | 26 +- types/flagtype/flagtype.go | 90 +- types/ipproto/ipproto.go | 398 ++--- types/key/chal.go | 182 +-- types/key/control.go | 136 +- types/key/control_test.go | 76 +- types/key/disco_test.go | 166 +- types/key/machine.go | 528 +++--- types/key/machine_test.go | 238 +-- types/key/nl_test.go | 96 +- types/lazy/unsync.go | 198 +-- types/lazy/unsync_test.go | 280 ++-- types/logger/rusage.go | 46 +- types/logger/rusage_stub.go | 22 +- types/logger/rusage_syscall.go | 58 +- types/logger/tokenbucket.go | 126 +- types/netlogtype/netlogtype.go | 200 +-- types/netlogtype/netlogtype_test.go | 78 +- types/netmap/netmap_test.go | 636 ++++---- types/nettype/nettype.go | 130 +- types/preftype/netfiltermode.go | 92 +- types/ptr/ptr.go | 20 +- types/structs/structs.go | 30 +- types/tkatype/tkatype.go | 80 +- types/tkatype/tkatype_test.go | 86 +- util/cibuild/cibuild.go | 28 +- util/cstruct/cstruct.go | 356 ++-- util/cstruct/cstruct_example_test.go | 146 +- util/deephash/debug.go | 74 +- util/deephash/pointer.go | 228 +-- util/deephash/pointer_norace.go | 26 +- util/deephash/pointer_race.go | 198 +-- util/deephash/testtype/testtype.go | 30 +- util/dirwalk/dirwalk.go | 106 +- util/dirwalk/dirwalk_linux.go | 334 ++-- util/dirwalk/dirwalk_test.go | 182 +-- util/goroutines/goroutines.go | 186 +-- util/goroutines/goroutines_test.go | 58 +- util/groupmember/groupmember.go | 58 +- util/hashx/block512.go | 394 ++--- util/httphdr/httphdr.go | 394 ++--- util/httphdr/httphdr_test.go | 192 +-- util/httpm/httpm.go | 72 +- util/httpm/httpm_test.go | 74 +- util/jsonutil/types.go | 32 +- util/jsonutil/unmarshal.go | 178 +- util/lineread/lineread.go | 74 +- util/linuxfw/linuxfwtest/linuxfwtest.go | 62 +- .../linuxfwtest/linuxfwtest_unsupported.go | 36 +- util/linuxfw/nftables_types.go | 190 +-- util/mak/mak.go | 140 +- util/mak/mak_test.go | 176 +- util/multierr/multierr.go | 272 ++-- util/must/must.go | 50 +- util/osdiag/mksyscall.go | 26 +- util/osdiag/osdiag_windows_test.go | 256 +-- util/osshare/filesharingstatus_noop.go | 24 +- util/pidowner/pidowner.go | 48 +- util/pidowner/pidowner_noimpl.go | 16 +- util/pidowner/pidowner_windows.go | 70 +- util/precompress/precompress.go | 258 +-- util/quarantine/quarantine.go | 28 +- util/quarantine/quarantine_darwin.go | 112 +- util/quarantine/quarantine_default.go | 28 +- util/quarantine/quarantine_windows.go | 58 +- util/race/race_test.go | 198 +-- util/racebuild/off.go | 16 +- util/racebuild/on.go | 16 +- util/racebuild/racebuild.go | 12 +- util/rands/rands.go | 50 +- util/rands/rands_test.go | 30 +- util/set/handle.go | 56 +- util/set/slice_test.go | 112 +- util/sysresources/memory.go | 20 +- util/sysresources/memory_bsd.go | 32 +- util/sysresources/memory_darwin.go | 32 +- util/sysresources/memory_linux.go | 38 +- util/sysresources/memory_unsupported.go | 16 +- util/sysresources/sysresources.go | 12 +- util/sysresources/sysresources_test.go | 50 +- util/systemd/doc.go | 26 +- util/systemd/systemd_linux.go | 154 +- util/systemd/systemd_nonlinux.go | 18 +- util/testenv/testenv.go | 42 +- util/truncate/truncate_test.go | 72 +- util/uniq/slice.go | 124 +- util/winutil/authenticode/mksyscall.go | 36 +- util/winutil/policy/policy_windows.go | 310 ++-- util/winutil/policy/policy_windows_test.go | 76 +- version/.gitignore | 20 +- version/cmdname.go | 278 ++-- version/cmdname_ios.go | 36 +- version/cmp_test.go | 164 +- version/export_test.go | 28 +- version/print.go | 66 +- version/race.go | 20 +- version/race_off.go | 20 +- version/version_test.go | 102 +- wgengine/bench/bench.go | 818 +++++----- wgengine/bench/bench_test.go | 216 +-- wgengine/bench/trafficgen.go | 518 +++--- wgengine/capture/capture.go | 476 +++--- wgengine/magicsock/blockforever_conn.go | 110 +- wgengine/magicsock/endpoint_default.go | 44 +- wgengine/magicsock/endpoint_stub.go | 26 +- wgengine/magicsock/endpoint_tracker.go | 496 +++--- wgengine/magicsock/magicsock_unix_test.go | 120 +- wgengine/magicsock/peermtu_darwin.go | 102 +- wgengine/magicsock/peermtu_linux.go | 98 +- wgengine/magicsock/peermtu_unix.go | 84 +- wgengine/mem_ios.go | 40 +- wgengine/netstack/netstack_linux.go | 38 +- wgengine/router/runner.go | 240 +-- wgengine/watchdog_js.go | 34 +- wgengine/wgcfg/device.go | 136 +- wgengine/wgcfg/device_test.go | 522 +++--- wgengine/wgcfg/parser.go | 372 ++--- wgengine/winnet/winnet_windows.go | 52 +- words/words.go | 116 +- words/words_test.go | 76 +- 554 files changed, 44582 insertions(+), 44582 deletions(-) diff --git a/.bencher/config.yaml b/.bencher/config.yaml index b60c5c352..220bd9d3b 100644 --- a/.bencher/config.yaml +++ b/.bencher/config.yaml @@ -1 +1 @@ -suppress_failure_on_regression: true +suppress_failure_on_regression: true diff --git a/.gitattributes b/.gitattributes index 38a6b06a3..3eb528782 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,2 +1,2 @@ -go.mod filter=go-mod -*.go diff=golang +go.mod filter=go-mod +*.go diff=golang diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index 688de1444..9163171c9 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -1,81 +1,81 @@ -name: Bug report -description: File a bug report. If you need help, contact support instead -labels: [needs-triage, bug] -body: - - type: markdown - attributes: - value: | - Need help with your tailnet? [Contact support](https://tailscale.com/contact/support) instead. - Otherwise, please check if your bug is [already filed](https://github.com/tailscale/tailscale/issues) before filing a new one. - - type: textarea - id: what-happened - attributes: - label: What is the issue? - description: What happened? What did you expect to happen? - validations: - required: true - - type: textarea - id: steps - attributes: - label: Steps to reproduce - description: What are the steps you took that hit this issue? - validations: - required: false - - type: textarea - id: changes - attributes: - label: Are there any recent changes that introduced the issue? - description: If so, what are those changes? - validations: - required: false - - type: dropdown - id: os - attributes: - label: OS - description: What OS are you using? You may select more than one. - multiple: true - options: - - Linux - - macOS - - Windows - - iOS - - Android - - Synology - - Other - validations: - required: false - - type: input - id: os-version - attributes: - label: OS version - description: What OS version are you using? - placeholder: e.g., Debian 11.0, macOS Big Sur 11.6, Synology DSM 7 - validations: - required: false - - type: input - id: ts-version - attributes: - label: Tailscale version - description: What Tailscale version are you using? - placeholder: e.g., 1.14.4 - validations: - required: false - - type: textarea - id: other-software - attributes: - label: Other software - description: What [other software](https://github.com/tailscale/tailscale/wiki/OtherSoftwareInterop) (networking, security, etc) are you running? - validations: - required: false - - type: input - id: bug-report - attributes: - label: Bug report - description: Please run [`tailscale bugreport`](https://tailscale.com/kb/1080/cli/?q=Cli#bugreport) and share the bug identifier. The identifier is a random string which allows Tailscale support to locate your account and gives a point to focus on when looking for errors. - placeholder: e.g., BUG-1b7641a16971a9cd75822c0ed8043fee70ae88cf05c52981dc220eb96a5c49a8-20210427151443Z-fbcd4fd3a4b7ad94 - validations: - required: false - - type: markdown - attributes: - value: | - Thanks for filing a bug report! +name: Bug report +description: File a bug report. If you need help, contact support instead +labels: [needs-triage, bug] +body: + - type: markdown + attributes: + value: | + Need help with your tailnet? [Contact support](https://tailscale.com/contact/support) instead. + Otherwise, please check if your bug is [already filed](https://github.com/tailscale/tailscale/issues) before filing a new one. + - type: textarea + id: what-happened + attributes: + label: What is the issue? + description: What happened? What did you expect to happen? + validations: + required: true + - type: textarea + id: steps + attributes: + label: Steps to reproduce + description: What are the steps you took that hit this issue? + validations: + required: false + - type: textarea + id: changes + attributes: + label: Are there any recent changes that introduced the issue? + description: If so, what are those changes? + validations: + required: false + - type: dropdown + id: os + attributes: + label: OS + description: What OS are you using? You may select more than one. + multiple: true + options: + - Linux + - macOS + - Windows + - iOS + - Android + - Synology + - Other + validations: + required: false + - type: input + id: os-version + attributes: + label: OS version + description: What OS version are you using? + placeholder: e.g., Debian 11.0, macOS Big Sur 11.6, Synology DSM 7 + validations: + required: false + - type: input + id: ts-version + attributes: + label: Tailscale version + description: What Tailscale version are you using? + placeholder: e.g., 1.14.4 + validations: + required: false + - type: textarea + id: other-software + attributes: + label: Other software + description: What [other software](https://github.com/tailscale/tailscale/wiki/OtherSoftwareInterop) (networking, security, etc) are you running? + validations: + required: false + - type: input + id: bug-report + attributes: + label: Bug report + description: Please run [`tailscale bugreport`](https://tailscale.com/kb/1080/cli/?q=Cli#bugreport) and share the bug identifier. The identifier is a random string which allows Tailscale support to locate your account and gives a point to focus on when looking for errors. + placeholder: e.g., BUG-1b7641a16971a9cd75822c0ed8043fee70ae88cf05c52981dc220eb96a5c49a8-20210427151443Z-fbcd4fd3a4b7ad94 + validations: + required: false + - type: markdown + attributes: + value: | + Thanks for filing a bug report! diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index e3c44b6a1..3f4a31534 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,8 +1,8 @@ -blank_issues_enabled: true -contact_links: - - name: Support - url: https://tailscale.com/contact/support/ - about: Contact us for support - - name: Troubleshooting - url: https://tailscale.com/kb/1023/troubleshooting +blank_issues_enabled: true +contact_links: + - name: Support + url: https://tailscale.com/contact/support/ + about: Contact us for support + - name: Troubleshooting + url: https://tailscale.com/kb/1023/troubleshooting about: See the troubleshooting guide for help addressing common issues \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml index 02ecae13c..f75386274 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.yml +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -1,42 +1,42 @@ -name: Feature request -description: Propose a new feature -title: "FR: " -labels: [needs-triage, fr] -body: - - type: markdown - attributes: - value: | - Please check if your feature request is [already filed](https://github.com/tailscale/tailscale/issues). - Tell us about your idea! - - type: textarea - id: problem - attributes: - label: What are you trying to do? - description: Tell us about the problem you're trying to solve. - validations: - required: false - - type: textarea - id: solution - attributes: - label: How should we solve this? - description: If you have an idea of how you'd like to see this feature work, let us know. - validations: - required: false - - type: textarea - id: alternative - attributes: - label: What is the impact of not solving this? - description: (How) Are you currently working around the issue? - validations: - required: false - - type: textarea - id: context - attributes: - label: Anything else? - description: Any additional context to share, e.g., links - validations: - required: false - - type: markdown - attributes: - value: | - Thanks for filing a feature request! +name: Feature request +description: Propose a new feature +title: "FR: " +labels: [needs-triage, fr] +body: + - type: markdown + attributes: + value: | + Please check if your feature request is [already filed](https://github.com/tailscale/tailscale/issues). + Tell us about your idea! + - type: textarea + id: problem + attributes: + label: What are you trying to do? + description: Tell us about the problem you're trying to solve. + validations: + required: false + - type: textarea + id: solution + attributes: + label: How should we solve this? + description: If you have an idea of how you'd like to see this feature work, let us know. + validations: + required: false + - type: textarea + id: alternative + attributes: + label: What is the impact of not solving this? + description: (How) Are you currently working around the issue? + validations: + required: false + - type: textarea + id: context + attributes: + label: Anything else? + description: Any additional context to share, e.g., links + validations: + required: false + - type: markdown + attributes: + value: | + Thanks for filing a feature request! diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 225132e54..14c912905 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -1,21 +1,21 @@ -# Documentation for this file can be found at: -# https://docs.github.com/en/code-security/supply-chain-security/keeping-your-dependencies-updated-automatically/configuration-options-for-dependency-updates -version: 2 -updates: - ## Disabled between releases. We reenable it briefly after every - ## stable release, pull in all changes, and close it again so that - ## the tree remains more stable during development and the upstream - ## changes have time to soak before the next release. - # - package-ecosystem: "gomod" - # directory: "/" - # schedule: - # interval: "daily" - # commit-message: - # prefix: "go.mod:" - # open-pull-requests-limit: 100 - - package-ecosystem: "github-actions" - directory: "/" - schedule: - interval: "weekly" - commit-message: - prefix: ".github:" +# Documentation for this file can be found at: +# https://docs.github.com/en/code-security/supply-chain-security/keeping-your-dependencies-updated-automatically/configuration-options-for-dependency-updates +version: 2 +updates: + ## Disabled between releases. We reenable it briefly after every + ## stable release, pull in all changes, and close it again so that + ## the tree remains more stable during development and the upstream + ## changes have time to soak before the next release. + # - package-ecosystem: "gomod" + # directory: "/" + # schedule: + # interval: "daily" + # commit-message: + # prefix: "go.mod:" + # open-pull-requests-limit: 100 + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" + commit-message: + prefix: ".github:" diff --git a/AUTHORS b/AUTHORS index 3fafc4492..03d5932c0 100644 --- a/AUTHORS +++ b/AUTHORS @@ -1,17 +1,17 @@ -# This is the official list of Tailscale -# authors for copyright purposes. -# -# Names should be added to this file as one of -# Organization's name -# Individual's name -# Individual's name -# -# Please keep the list sorted. -# -# You do not need to add entries to this list, and we don't actively -# populate this list. If you do want to be acknowledged explicitly as -# a copyright holder, though, then please send a PR referencing your -# earlier contributions and clarifying whether it's you or your -# company that owns the rights to your contribution. - -Tailscale Inc. +# This is the official list of Tailscale +# authors for copyright purposes. +# +# Names should be added to this file as one of +# Organization's name +# Individual's name +# Individual's name +# +# Please keep the list sorted. +# +# You do not need to add entries to this list, and we don't actively +# populate this list. If you do want to be acknowledged explicitly as +# a copyright holder, though, then please send a PR referencing your +# earlier contributions and clarifying whether it's you or your +# company that owns the rights to your contribution. + +Tailscale Inc. diff --git a/CODEOWNERS b/CODEOWNERS index 76edf1006..af9b0d9f9 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1 +1 @@ -/tailcfg/ @tailscale/control-protocol-owners +/tailcfg/ @tailscale/control-protocol-owners diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index cf4e6ddbe..be5564ef4 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -1,135 +1,135 @@ -# Contributor Covenant Code of Conduct - -## Our Pledge - -We as members, contributors, and leaders pledge to make participation -in our community a harassment-free experience for everyone, regardless -of age, body size, visible or invisible disability, ethnicity, sex -characteristics, gender identity and expression, level of experience, -education, socio-economic status, nationality, personal appearance, -race, religion, or sexual identity and orientation. - -We pledge to act and interact in ways that contribute to an open, -welcoming, diverse, inclusive, and healthy community. - -## Our Standards - -Examples of behavior that contributes to a positive environment for -our community include: - -* Demonstrating empathy and kindness toward other people -* Being respectful of differing opinions, viewpoints, and experiences -* Giving and gracefully accepting constructive feedback -* Accepting responsibility and apologizing to those affected by our - mistakes, and learning from the experience -* Focusing on what is best not just for us as individuals, but for the - overall community - -Examples of unacceptable behavior include: - -* The use of sexualized language or imagery, and sexual attention or - advances of any kind -* Trolling, insulting or derogatory comments, and personal or - political attacks -* Public or private harassment -* Publishing others' private information, such as a physical or email - address, without their explicit permission -* Other conduct which could reasonably be considered inappropriate in - a professional setting - -## Enforcement Responsibilities - -Community leaders are responsible for clarifying and enforcing our -standards of acceptable behavior and will take appropriate and fair -corrective action in response to any behavior that they deem -inappropriate, threatening, offensive, or harmful. - -Community leaders have the right and responsibility to remove, edit, -or reject comments, commits, code, wiki edits, issues, and other -contributions that are not aligned to this Code of Conduct, and will -communicate reasons for moderation decisions when appropriate. - -## Scope - -This Code of Conduct applies within all community spaces, and also -applies when an individual is officially representing the community in -public spaces. Examples of representing our community include using an -official e-mail address, posting via an official social media account, -or acting as an appointed representative at an online or offline -event. - -## Enforcement - -Instances of abusive, harassing, or otherwise unacceptable behavior -may be reported to the community leaders responsible for enforcement -at [info@tailscale.com](mailto:info@tailscale.com). All complaints -will be reviewed and investigated promptly and fairly. - -All community leaders are obligated to respect the privacy and -security of the reporter of any incident. - -## Enforcement Guidelines - -Community leaders will follow these Community Impact Guidelines in -determining the consequences for any action they deem in violation of -this Code of Conduct: - -### 1. Correction - -**Community Impact**: Use of inappropriate language or other behavior -deemed unprofessional or unwelcome in the community. - -**Consequence**: A private, written warning from community leaders, -providing clarity around the nature of the violation and an -explanation of why the behavior was inappropriate. A public apology -may be requested. - -### 2. Warning - -**Community Impact**: A violation through a single incident or series -of actions. - -**Consequence**: A warning with consequences for continued -behavior. No interaction with the people involved, including -unsolicited interaction with those enforcing the Code of Conduct, for -a specified period of time. This includes avoiding interactions in -community spaces as well as external channels like social -media. Violating these terms may lead to a temporary or permanent ban. - -### 3. Temporary Ban - -**Community Impact**: A serious violation of community standards, -including sustained inappropriate behavior. - -**Consequence**: A temporary ban from any sort of interaction or -public communication with the community for a specified period of -time. No public or private interaction with the people involved, -including unsolicited interaction with those enforcing the Code of -Conduct, is allowed during this period. Violating these terms may lead -to a permanent ban. - -### 4. Permanent Ban - -**Community Impact**: Demonstrating a pattern of violation of -community standards, including sustained inappropriate behavior, -harassment of an individual, or aggression toward or disparagement of -classes of individuals. - -**Consequence**: A permanent ban from any sort of public interaction -within the community. - -## Attribution - -This Code of Conduct is adapted from the [Contributor -Covenant][homepage], version 2.0, available at -https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. - -Community Impact Guidelines were inspired by [Mozilla's code of -conduct enforcement ladder](https://github.com/mozilla/diversity). - -[homepage]: https://www.contributor-covenant.org - -For answers to common questions about this code of conduct, see the -FAQ at https://www.contributor-covenant.org/faq. Translations are -available at https://www.contributor-covenant.org/translations. - +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation +in our community a harassment-free experience for everyone, regardless +of age, body size, visible or invisible disability, ethnicity, sex +characteristics, gender identity and expression, level of experience, +education, socio-economic status, nationality, personal appearance, +race, religion, or sexual identity and orientation. + +We pledge to act and interact in ways that contribute to an open, +welcoming, diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for +our community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our + mistakes, and learning from the experience +* Focusing on what is best not just for us as individuals, but for the + overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or + advances of any kind +* Trolling, insulting or derogatory comments, and personal or + political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email + address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in + a professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our +standards of acceptable behavior and will take appropriate and fair +corrective action in response to any behavior that they deem +inappropriate, threatening, offensive, or harmful. + +Community leaders have the right and responsibility to remove, edit, +or reject comments, commits, code, wiki edits, issues, and other +contributions that are not aligned to this Code of Conduct, and will +communicate reasons for moderation decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also +applies when an individual is officially representing the community in +public spaces. Examples of representing our community include using an +official e-mail address, posting via an official social media account, +or acting as an appointed representative at an online or offline +event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior +may be reported to the community leaders responsible for enforcement +at [info@tailscale.com](mailto:info@tailscale.com). All complaints +will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and +security of the reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in +determining the consequences for any action they deem in violation of +this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior +deemed unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, +providing clarity around the nature of the violation and an +explanation of why the behavior was inappropriate. A public apology +may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series +of actions. + +**Consequence**: A warning with consequences for continued +behavior. No interaction with the people involved, including +unsolicited interaction with those enforcing the Code of Conduct, for +a specified period of time. This includes avoiding interactions in +community spaces as well as external channels like social +media. Violating these terms may lead to a temporary or permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, +including sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or +public communication with the community for a specified period of +time. No public or private interaction with the people involved, +including unsolicited interaction with those enforcing the Code of +Conduct, is allowed during this period. Violating these terms may lead +to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of +community standards, including sustained inappropriate behavior, +harassment of an individual, or aggression toward or disparagement of +classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction +within the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor +Covenant][homepage], version 2.0, available at +https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. + +Community Impact Guidelines were inspired by [Mozilla's code of +conduct enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the +FAQ at https://www.contributor-covenant.org/faq. Translations are +available at https://www.contributor-covenant.org/translations. + diff --git a/LICENSE b/LICENSE index 3d511c30c..394db19e4 100644 --- a/LICENSE +++ b/LICENSE @@ -1,28 +1,28 @@ -BSD 3-Clause License - -Copyright (c) 2020 Tailscale Inc & AUTHORS. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -1. Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -2. Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -3. Neither the name of the copyright holder nor the names of its - contributors may be used to endorse or promote products derived from - this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +BSD 3-Clause License + +Copyright (c) 2020 Tailscale Inc & AUTHORS. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/PATENTS b/PATENTS index b001fb9c1..560a2b8f0 100644 --- a/PATENTS +++ b/PATENTS @@ -1,24 +1,24 @@ -Additional IP Rights Grant (Patents) - -"This implementation" means the copyrightable works distributed by -Tailscale Inc. as part of the Tailscale project. - -Tailscale Inc. hereby grants to You a perpetual, worldwide, -non-exclusive, no-charge, royalty-free, irrevocable (except as stated -in this section) patent license to make, have made, use, offer to -sell, sell, import, transfer and otherwise run, modify and propagate -the contents of this implementation of Tailscale, where such license -applies only to those patent claims, both currently owned or -controlled by Tailscale Inc. and acquired in the future, licensable -by Tailscale Inc. that are necessarily infringed by this -implementation of Tailscale. This grant does not include claims that -would be infringed only as a consequence of further modification of -this implementation. If you or your agent or exclusive licensee -institute or order or agree to the institution of patent litigation -against any entity (including a cross-claim or counterclaim in a -lawsuit) alleging that this implementation of Tailscale or any code -incorporated within this implementation of Tailscale constitutes -direct or contributory patent infringement, or inducement of patent -infringement, then any patent rights granted to you under this License -for this implementation of Tailscale shall terminate as of the date -such litigation is filed. +Additional IP Rights Grant (Patents) + +"This implementation" means the copyrightable works distributed by +Tailscale Inc. as part of the Tailscale project. + +Tailscale Inc. hereby grants to You a perpetual, worldwide, +non-exclusive, no-charge, royalty-free, irrevocable (except as stated +in this section) patent license to make, have made, use, offer to +sell, sell, import, transfer and otherwise run, modify and propagate +the contents of this implementation of Tailscale, where such license +applies only to those patent claims, both currently owned or +controlled by Tailscale Inc. and acquired in the future, licensable +by Tailscale Inc. that are necessarily infringed by this +implementation of Tailscale. This grant does not include claims that +would be infringed only as a consequence of further modification of +this implementation. If you or your agent or exclusive licensee +institute or order or agree to the institution of patent litigation +against any entity (including a cross-claim or counterclaim in a +lawsuit) alleging that this implementation of Tailscale or any code +incorporated within this implementation of Tailscale constitutes +direct or contributory patent infringement, or inducement of patent +infringement, then any patent rights granted to you under this License +for this implementation of Tailscale shall terminate as of the date +such litigation is filed. diff --git a/SECURITY.md b/SECURITY.md index e8cd9a326..26702b141 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,8 +1,8 @@ -# Security Policy - -## Reporting a Vulnerability - -You can report vulnerabilities privately to -[security@tailscale.com](mailto:security@tailscale.com). Tailscale -staff will triage the issue, and work with you on a coordinated -disclosure timeline. +# Security Policy + +## Reporting a Vulnerability + +You can report vulnerabilities privately to +[security@tailscale.com](mailto:security@tailscale.com). Tailscale +staff will triage the issue, and work with you on a coordinated +disclosure timeline. diff --git a/VERSION.txt b/VERSION.txt index 54227249d..79e15fd49 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1 @@ -1.78.0 +1.77.0 diff --git a/atomicfile/atomicfile.go b/atomicfile/atomicfile.go index b95c7cbe1..5c18e85a8 100644 --- a/atomicfile/atomicfile.go +++ b/atomicfile/atomicfile.go @@ -1,51 +1,51 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package atomicfile contains code related to writing to filesystems -// atomically. -// -// This package should be considered internal; its API is not stable. -package atomicfile // import "tailscale.com/atomicfile" - -import ( - "fmt" - "os" - "path/filepath" - "runtime" -) - -// WriteFile writes data to filename+some suffix, then renames it into filename. -// The perm argument is ignored on Windows. If the target filename already -// exists but is not a regular file, WriteFile returns an error. -func WriteFile(filename string, data []byte, perm os.FileMode) (err error) { - fi, err := os.Stat(filename) - if err == nil && !fi.Mode().IsRegular() { - return fmt.Errorf("%s already exists and is not a regular file", filename) - } - f, err := os.CreateTemp(filepath.Dir(filename), filepath.Base(filename)+".tmp") - if err != nil { - return err - } - tmpName := f.Name() - defer func() { - if err != nil { - f.Close() - os.Remove(tmpName) - } - }() - if _, err := f.Write(data); err != nil { - return err - } - if runtime.GOOS != "windows" { - if err := f.Chmod(perm); err != nil { - return err - } - } - if err := f.Sync(); err != nil { - return err - } - if err := f.Close(); err != nil { - return err - } - return os.Rename(tmpName, filename) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package atomicfile contains code related to writing to filesystems +// atomically. +// +// This package should be considered internal; its API is not stable. +package atomicfile // import "tailscale.com/atomicfile" + +import ( + "fmt" + "os" + "path/filepath" + "runtime" +) + +// WriteFile writes data to filename+some suffix, then renames it into filename. +// The perm argument is ignored on Windows. If the target filename already +// exists but is not a regular file, WriteFile returns an error. +func WriteFile(filename string, data []byte, perm os.FileMode) (err error) { + fi, err := os.Stat(filename) + if err == nil && !fi.Mode().IsRegular() { + return fmt.Errorf("%s already exists and is not a regular file", filename) + } + f, err := os.CreateTemp(filepath.Dir(filename), filepath.Base(filename)+".tmp") + if err != nil { + return err + } + tmpName := f.Name() + defer func() { + if err != nil { + f.Close() + os.Remove(tmpName) + } + }() + if _, err := f.Write(data); err != nil { + return err + } + if runtime.GOOS != "windows" { + if err := f.Chmod(perm); err != nil { + return err + } + } + if err := f.Sync(); err != nil { + return err + } + if err := f.Close(); err != nil { + return err + } + return os.Rename(tmpName, filename) +} diff --git a/atomicfile/atomicfile_test.go b/atomicfile/atomicfile_test.go index b7a78765b..78c93e664 100644 --- a/atomicfile/atomicfile_test.go +++ b/atomicfile/atomicfile_test.go @@ -1,47 +1,47 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !js && !windows - -package atomicfile - -import ( - "net" - "os" - "path/filepath" - "runtime" - "strings" - "testing" -) - -func TestDoesNotOverwriteIrregularFiles(t *testing.T) { - // Per tailscale/tailscale#7658 as one example, almost any imagined use of - // atomicfile.Write should likely not attempt to overwrite an irregular file - // such as a device node, socket, or named pipe. - - const filename = "TestDoesNotOverwriteIrregularFiles" - var path string - // macOS private temp does not allow unix socket creation, but /tmp does. - if runtime.GOOS == "darwin" { - path = filepath.Join("/tmp", filename) - t.Cleanup(func() { os.Remove(path) }) - } else { - path = filepath.Join(t.TempDir(), filename) - } - - // The least troublesome thing to make that is not a file is a unix socket. - // Making a null device sadly requires root. - l, err := net.ListenUnix("unix", &net.UnixAddr{Name: path, Net: "unix"}) - if err != nil { - t.Fatal(err) - } - defer l.Close() - - err = WriteFile(path, []byte("hello"), 0644) - if err == nil { - t.Fatal("expected error, got nil") - } - if !strings.Contains(err.Error(), "is not a regular file") { - t.Fatalf("unexpected error: %v", err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !js && !windows + +package atomicfile + +import ( + "net" + "os" + "path/filepath" + "runtime" + "strings" + "testing" +) + +func TestDoesNotOverwriteIrregularFiles(t *testing.T) { + // Per tailscale/tailscale#7658 as one example, almost any imagined use of + // atomicfile.Write should likely not attempt to overwrite an irregular file + // such as a device node, socket, or named pipe. + + const filename = "TestDoesNotOverwriteIrregularFiles" + var path string + // macOS private temp does not allow unix socket creation, but /tmp does. + if runtime.GOOS == "darwin" { + path = filepath.Join("/tmp", filename) + t.Cleanup(func() { os.Remove(path) }) + } else { + path = filepath.Join(t.TempDir(), filename) + } + + // The least troublesome thing to make that is not a file is a unix socket. + // Making a null device sadly requires root. + l, err := net.ListenUnix("unix", &net.UnixAddr{Name: path, Net: "unix"}) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + err = WriteFile(path, []byte("hello"), 0644) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "is not a regular file") { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/chirp/chirp.go b/chirp/chirp.go index 1b448f239..965387722 100644 --- a/chirp/chirp.go +++ b/chirp/chirp.go @@ -1,163 +1,163 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package chirp implements a client to communicate with the BIRD Internet -// Routing Daemon. -package chirp - -import ( - "bufio" - "fmt" - "net" - "strings" - "time" -) - -const ( - // Maximum amount of time we should wait when reading a response from BIRD. - responseTimeout = 10 * time.Second -) - -// New creates a BIRDClient. -func New(socket string) (*BIRDClient, error) { - return newWithTimeout(socket, responseTimeout) -} - -func newWithTimeout(socket string, timeout time.Duration) (_ *BIRDClient, err error) { - conn, err := net.Dial("unix", socket) - if err != nil { - return nil, fmt.Errorf("failed to connect to BIRD: %w", err) - } - defer func() { - if err != nil { - conn.Close() - } - }() - - b := &BIRDClient{ - socket: socket, - conn: conn, - scanner: bufio.NewScanner(conn), - timeNow: time.Now, - timeout: timeout, - } - // Read and discard the first line as that is the welcome message. - if _, err := b.readResponse(); err != nil { - return nil, err - } - return b, nil -} - -// BIRDClient handles communication with the BIRD Internet Routing Daemon. -type BIRDClient struct { - socket string - conn net.Conn - scanner *bufio.Scanner - timeNow func() time.Time - timeout time.Duration -} - -// Close closes the underlying connection to BIRD. -func (b *BIRDClient) Close() error { return b.conn.Close() } - -// DisableProtocol disables the provided protocol. -func (b *BIRDClient) DisableProtocol(protocol string) error { - out, err := b.exec("disable %s", protocol) - if err != nil { - return err - } - if strings.Contains(out, fmt.Sprintf("%s: already disabled", protocol)) { - return nil - } else if strings.Contains(out, fmt.Sprintf("%s: disabled", protocol)) { - return nil - } - return fmt.Errorf("failed to disable %s: %v", protocol, out) -} - -// EnableProtocol enables the provided protocol. -func (b *BIRDClient) EnableProtocol(protocol string) error { - out, err := b.exec("enable %s", protocol) - if err != nil { - return err - } - if strings.Contains(out, fmt.Sprintf("%s: already enabled", protocol)) { - return nil - } else if strings.Contains(out, fmt.Sprintf("%s: enabled", protocol)) { - return nil - } - return fmt.Errorf("failed to enable %s: %v", protocol, out) -} - -// BIRD CLI docs from https://bird.network.cz/?get_doc&v=20&f=prog-2.html#ss2.9 - -// Each session of the CLI consists of a sequence of request and replies, -// slightly resembling the FTP and SMTP protocols. -// Requests are commands encoded as a single line of text, -// replies are sequences of lines starting with a four-digit code -// followed by either a space (if it's the last line of the reply) or -// a minus sign (when the reply is going to continue with the next line), -// the rest of the line contains a textual message semantics of which depends on the numeric code. -// If a reply line has the same code as the previous one and it's a continuation line, -// the whole prefix can be replaced by a single white space character. -// -// Reply codes starting with 0 stand for ‘action successfully completed’ messages, -// 1 means ‘table entry’, 8 ‘runtime error’ and 9 ‘syntax error’. - -func (b *BIRDClient) exec(cmd string, args ...any) (string, error) { - if err := b.conn.SetWriteDeadline(b.timeNow().Add(b.timeout)); err != nil { - return "", err - } - if _, err := fmt.Fprintf(b.conn, cmd, args...); err != nil { - return "", err - } - if _, err := fmt.Fprintln(b.conn); err != nil { - return "", err - } - return b.readResponse() -} - -// hasResponseCode reports whether the provided byte slice is -// prefixed with a BIRD response code. -// Equivalent regex: `^\d{4}[ -]`. -func hasResponseCode(s []byte) bool { - if len(s) < 5 { - return false - } - for _, b := range s[:4] { - if '0' <= b && b <= '9' { - continue - } - return false - } - return s[4] == ' ' || s[4] == '-' -} - -func (b *BIRDClient) readResponse() (string, error) { - // Set the read timeout before we start reading anything. - if err := b.conn.SetReadDeadline(b.timeNow().Add(b.timeout)); err != nil { - return "", err - } - - var resp strings.Builder - var done bool - for !done { - if !b.scanner.Scan() { - if err := b.scanner.Err(); err != nil { - return "", err - } - - return "", fmt.Errorf("reading response from bird failed (EOF): %q", resp.String()) - } - out := b.scanner.Bytes() - if _, err := resp.Write(out); err != nil { - return "", err - } - if hasResponseCode(out) { - done = out[4] == ' ' - } - if !done { - resp.WriteRune('\n') - } - } - return resp.String(), nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package chirp implements a client to communicate with the BIRD Internet +// Routing Daemon. +package chirp + +import ( + "bufio" + "fmt" + "net" + "strings" + "time" +) + +const ( + // Maximum amount of time we should wait when reading a response from BIRD. + responseTimeout = 10 * time.Second +) + +// New creates a BIRDClient. +func New(socket string) (*BIRDClient, error) { + return newWithTimeout(socket, responseTimeout) +} + +func newWithTimeout(socket string, timeout time.Duration) (_ *BIRDClient, err error) { + conn, err := net.Dial("unix", socket) + if err != nil { + return nil, fmt.Errorf("failed to connect to BIRD: %w", err) + } + defer func() { + if err != nil { + conn.Close() + } + }() + + b := &BIRDClient{ + socket: socket, + conn: conn, + scanner: bufio.NewScanner(conn), + timeNow: time.Now, + timeout: timeout, + } + // Read and discard the first line as that is the welcome message. + if _, err := b.readResponse(); err != nil { + return nil, err + } + return b, nil +} + +// BIRDClient handles communication with the BIRD Internet Routing Daemon. +type BIRDClient struct { + socket string + conn net.Conn + scanner *bufio.Scanner + timeNow func() time.Time + timeout time.Duration +} + +// Close closes the underlying connection to BIRD. +func (b *BIRDClient) Close() error { return b.conn.Close() } + +// DisableProtocol disables the provided protocol. +func (b *BIRDClient) DisableProtocol(protocol string) error { + out, err := b.exec("disable %s", protocol) + if err != nil { + return err + } + if strings.Contains(out, fmt.Sprintf("%s: already disabled", protocol)) { + return nil + } else if strings.Contains(out, fmt.Sprintf("%s: disabled", protocol)) { + return nil + } + return fmt.Errorf("failed to disable %s: %v", protocol, out) +} + +// EnableProtocol enables the provided protocol. +func (b *BIRDClient) EnableProtocol(protocol string) error { + out, err := b.exec("enable %s", protocol) + if err != nil { + return err + } + if strings.Contains(out, fmt.Sprintf("%s: already enabled", protocol)) { + return nil + } else if strings.Contains(out, fmt.Sprintf("%s: enabled", protocol)) { + return nil + } + return fmt.Errorf("failed to enable %s: %v", protocol, out) +} + +// BIRD CLI docs from https://bird.network.cz/?get_doc&v=20&f=prog-2.html#ss2.9 + +// Each session of the CLI consists of a sequence of request and replies, +// slightly resembling the FTP and SMTP protocols. +// Requests are commands encoded as a single line of text, +// replies are sequences of lines starting with a four-digit code +// followed by either a space (if it's the last line of the reply) or +// a minus sign (when the reply is going to continue with the next line), +// the rest of the line contains a textual message semantics of which depends on the numeric code. +// If a reply line has the same code as the previous one and it's a continuation line, +// the whole prefix can be replaced by a single white space character. +// +// Reply codes starting with 0 stand for ‘action successfully completed’ messages, +// 1 means ‘table entry’, 8 ‘runtime error’ and 9 ‘syntax error’. + +func (b *BIRDClient) exec(cmd string, args ...any) (string, error) { + if err := b.conn.SetWriteDeadline(b.timeNow().Add(b.timeout)); err != nil { + return "", err + } + if _, err := fmt.Fprintf(b.conn, cmd, args...); err != nil { + return "", err + } + if _, err := fmt.Fprintln(b.conn); err != nil { + return "", err + } + return b.readResponse() +} + +// hasResponseCode reports whether the provided byte slice is +// prefixed with a BIRD response code. +// Equivalent regex: `^\d{4}[ -]`. +func hasResponseCode(s []byte) bool { + if len(s) < 5 { + return false + } + for _, b := range s[:4] { + if '0' <= b && b <= '9' { + continue + } + return false + } + return s[4] == ' ' || s[4] == '-' +} + +func (b *BIRDClient) readResponse() (string, error) { + // Set the read timeout before we start reading anything. + if err := b.conn.SetReadDeadline(b.timeNow().Add(b.timeout)); err != nil { + return "", err + } + + var resp strings.Builder + var done bool + for !done { + if !b.scanner.Scan() { + if err := b.scanner.Err(); err != nil { + return "", err + } + + return "", fmt.Errorf("reading response from bird failed (EOF): %q", resp.String()) + } + out := b.scanner.Bytes() + if _, err := resp.Write(out); err != nil { + return "", err + } + if hasResponseCode(out) { + done = out[4] == ' ' + } + if !done { + resp.WriteRune('\n') + } + } + return resp.String(), nil +} diff --git a/chirp/chirp_test.go b/chirp/chirp_test.go index b8947a796..2549c163f 100644 --- a/chirp/chirp_test.go +++ b/chirp/chirp_test.go @@ -1,192 +1,192 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause -package chirp - -import ( - "bufio" - "errors" - "fmt" - "net" - "os" - "path/filepath" - "strings" - "sync" - "testing" - "time" -) - -type fakeBIRD struct { - net.Listener - protocolsEnabled map[string]bool - sock string -} - -func newFakeBIRD(t *testing.T, protocols ...string) *fakeBIRD { - sock := filepath.Join(t.TempDir(), "sock") - l, err := net.Listen("unix", sock) - if err != nil { - t.Fatal(err) - } - pe := make(map[string]bool) - for _, p := range protocols { - pe[p] = false - } - return &fakeBIRD{ - Listener: l, - protocolsEnabled: pe, - sock: sock, - } -} - -func (fb *fakeBIRD) listen() error { - for { - c, err := fb.Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) { - return nil - } - return err - } - go fb.handle(c) - } -} - -func (fb *fakeBIRD) handle(c net.Conn) { - fmt.Fprintln(c, "0001 BIRD 2.0.8 ready.") - sc := bufio.NewScanner(c) - for sc.Scan() { - cmd := sc.Text() - args := strings.Split(cmd, " ") - switch args[0] { - case "enable": - en, ok := fb.protocolsEnabled[args[1]] - if !ok { - fmt.Fprintln(c, "9001 syntax error, unexpected CF_SYM_UNDEFINED, expecting CF_SYM_KNOWN or TEXT or ALL") - } else if en { - fmt.Fprintf(c, "0010-%s: already enabled\n", args[1]) - } else { - fmt.Fprintf(c, "0011-%s: enabled\n", args[1]) - } - fmt.Fprintln(c, "0000 ") - fb.protocolsEnabled[args[1]] = true - case "disable": - en, ok := fb.protocolsEnabled[args[1]] - if !ok { - fmt.Fprintln(c, "9001 syntax error, unexpected CF_SYM_UNDEFINED, expecting CF_SYM_KNOWN or TEXT or ALL") - } else if !en { - fmt.Fprintf(c, "0008-%s: already disabled\n", args[1]) - } else { - fmt.Fprintf(c, "0009-%s: disabled\n", args[1]) - } - fmt.Fprintln(c, "0000 ") - fb.protocolsEnabled[args[1]] = false - } - } -} - -func TestChirp(t *testing.T) { - fb := newFakeBIRD(t, "tailscale") - defer fb.Close() - go fb.listen() - c, err := New(fb.sock) - if err != nil { - t.Fatal(err) - } - if err := c.EnableProtocol("tailscale"); err != nil { - t.Fatal(err) - } - if err := c.EnableProtocol("tailscale"); err != nil { - t.Fatal(err) - } - if err := c.DisableProtocol("tailscale"); err != nil { - t.Fatal(err) - } - if err := c.DisableProtocol("tailscale"); err != nil { - t.Fatal(err) - } - if err := c.EnableProtocol("rando"); err == nil { - t.Fatalf("enabling %q succeeded", "rando") - } - if err := c.DisableProtocol("rando"); err == nil { - t.Fatalf("disabling %q succeeded", "rando") - } -} - -type hangingListener struct { - net.Listener - t *testing.T - done chan struct{} - wg sync.WaitGroup - sock string -} - -func newHangingListener(t *testing.T) *hangingListener { - sock := filepath.Join(t.TempDir(), "sock") - l, err := net.Listen("unix", sock) - if err != nil { - t.Fatal(err) - } - return &hangingListener{ - Listener: l, - t: t, - done: make(chan struct{}), - sock: sock, - } -} - -func (hl *hangingListener) Stop() { - hl.Close() - close(hl.done) - hl.wg.Wait() -} - -func (hl *hangingListener) listen() error { - for { - c, err := hl.Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) { - return nil - } - return err - } - hl.wg.Add(1) - go hl.handle(c) - } -} - -func (hl *hangingListener) handle(c net.Conn) { - defer hl.wg.Done() - - // Write our fake first line of response so that we get into the read loop - fmt.Fprintln(c, "0001 BIRD 2.0.8 ready.") - - ticker := time.NewTicker(2 * time.Second) - defer ticker.Stop() - for { - select { - case <-ticker.C: - hl.t.Logf("connection still hanging") - case <-hl.done: - return - } - } -} - -func TestChirpTimeout(t *testing.T) { - fb := newHangingListener(t) - defer fb.Stop() - go fb.listen() - - c, err := newWithTimeout(fb.sock, 500*time.Millisecond) - if err != nil { - t.Fatal(err) - } - - err = c.EnableProtocol("tailscale") - if err == nil { - t.Fatal("got err=nil, want timeout") - } - if !os.IsTimeout(err) { - t.Fatalf("got err=%v, want os.IsTimeout(err)=true", err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause +package chirp + +import ( + "bufio" + "errors" + "fmt" + "net" + "os" + "path/filepath" + "strings" + "sync" + "testing" + "time" +) + +type fakeBIRD struct { + net.Listener + protocolsEnabled map[string]bool + sock string +} + +func newFakeBIRD(t *testing.T, protocols ...string) *fakeBIRD { + sock := filepath.Join(t.TempDir(), "sock") + l, err := net.Listen("unix", sock) + if err != nil { + t.Fatal(err) + } + pe := make(map[string]bool) + for _, p := range protocols { + pe[p] = false + } + return &fakeBIRD{ + Listener: l, + protocolsEnabled: pe, + sock: sock, + } +} + +func (fb *fakeBIRD) listen() error { + for { + c, err := fb.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return nil + } + return err + } + go fb.handle(c) + } +} + +func (fb *fakeBIRD) handle(c net.Conn) { + fmt.Fprintln(c, "0001 BIRD 2.0.8 ready.") + sc := bufio.NewScanner(c) + for sc.Scan() { + cmd := sc.Text() + args := strings.Split(cmd, " ") + switch args[0] { + case "enable": + en, ok := fb.protocolsEnabled[args[1]] + if !ok { + fmt.Fprintln(c, "9001 syntax error, unexpected CF_SYM_UNDEFINED, expecting CF_SYM_KNOWN or TEXT or ALL") + } else if en { + fmt.Fprintf(c, "0010-%s: already enabled\n", args[1]) + } else { + fmt.Fprintf(c, "0011-%s: enabled\n", args[1]) + } + fmt.Fprintln(c, "0000 ") + fb.protocolsEnabled[args[1]] = true + case "disable": + en, ok := fb.protocolsEnabled[args[1]] + if !ok { + fmt.Fprintln(c, "9001 syntax error, unexpected CF_SYM_UNDEFINED, expecting CF_SYM_KNOWN or TEXT or ALL") + } else if !en { + fmt.Fprintf(c, "0008-%s: already disabled\n", args[1]) + } else { + fmt.Fprintf(c, "0009-%s: disabled\n", args[1]) + } + fmt.Fprintln(c, "0000 ") + fb.protocolsEnabled[args[1]] = false + } + } +} + +func TestChirp(t *testing.T) { + fb := newFakeBIRD(t, "tailscale") + defer fb.Close() + go fb.listen() + c, err := New(fb.sock) + if err != nil { + t.Fatal(err) + } + if err := c.EnableProtocol("tailscale"); err != nil { + t.Fatal(err) + } + if err := c.EnableProtocol("tailscale"); err != nil { + t.Fatal(err) + } + if err := c.DisableProtocol("tailscale"); err != nil { + t.Fatal(err) + } + if err := c.DisableProtocol("tailscale"); err != nil { + t.Fatal(err) + } + if err := c.EnableProtocol("rando"); err == nil { + t.Fatalf("enabling %q succeeded", "rando") + } + if err := c.DisableProtocol("rando"); err == nil { + t.Fatalf("disabling %q succeeded", "rando") + } +} + +type hangingListener struct { + net.Listener + t *testing.T + done chan struct{} + wg sync.WaitGroup + sock string +} + +func newHangingListener(t *testing.T) *hangingListener { + sock := filepath.Join(t.TempDir(), "sock") + l, err := net.Listen("unix", sock) + if err != nil { + t.Fatal(err) + } + return &hangingListener{ + Listener: l, + t: t, + done: make(chan struct{}), + sock: sock, + } +} + +func (hl *hangingListener) Stop() { + hl.Close() + close(hl.done) + hl.wg.Wait() +} + +func (hl *hangingListener) listen() error { + for { + c, err := hl.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return nil + } + return err + } + hl.wg.Add(1) + go hl.handle(c) + } +} + +func (hl *hangingListener) handle(c net.Conn) { + defer hl.wg.Done() + + // Write our fake first line of response so that we get into the read loop + fmt.Fprintln(c, "0001 BIRD 2.0.8 ready.") + + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + hl.t.Logf("connection still hanging") + case <-hl.done: + return + } + } +} + +func TestChirpTimeout(t *testing.T) { + fb := newHangingListener(t) + defer fb.Stop() + go fb.listen() + + c, err := newWithTimeout(fb.sock, 500*time.Millisecond) + if err != nil { + t.Fatal(err) + } + + err = c.EnableProtocol("tailscale") + if err == nil { + t.Fatal("got err=nil, want timeout") + } + if !os.IsTimeout(err) { + t.Fatalf("got err=%v, want os.IsTimeout(err)=true", err) + } +} diff --git a/client/tailscale/apitype/controltype.go b/client/tailscale/apitype/controltype.go index a9a76065f..9a623be31 100644 --- a/client/tailscale/apitype/controltype.go +++ b/client/tailscale/apitype/controltype.go @@ -1,19 +1,19 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package apitype - -type DNSConfig struct { - Resolvers []DNSResolver `json:"resolvers"` - FallbackResolvers []DNSResolver `json:"fallbackResolvers"` - Routes map[string][]DNSResolver `json:"routes"` - Domains []string `json:"domains"` - Nameservers []string `json:"nameservers"` - Proxied bool `json:"proxied"` - TempCorpIssue13969 string `json:"TempCorpIssue13969,omitempty"` -} - -type DNSResolver struct { - Addr string `json:"addr"` - BootstrapResolution []string `json:"bootstrapResolution,omitempty"` -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package apitype + +type DNSConfig struct { + Resolvers []DNSResolver `json:"resolvers"` + FallbackResolvers []DNSResolver `json:"fallbackResolvers"` + Routes map[string][]DNSResolver `json:"routes"` + Domains []string `json:"domains"` + Nameservers []string `json:"nameservers"` + Proxied bool `json:"proxied"` + TempCorpIssue13969 string `json:"TempCorpIssue13969,omitempty"` +} + +type DNSResolver struct { + Addr string `json:"addr"` + BootstrapResolution []string `json:"bootstrapResolution,omitempty"` +} diff --git a/client/tailscale/dns.go b/client/tailscale/dns.go index 12b9e15c8..f198742b3 100644 --- a/client/tailscale/dns.go +++ b/client/tailscale/dns.go @@ -1,233 +1,233 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build go1.19 - -package tailscale - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "net/http" - - "tailscale.com/client/tailscale/apitype" -) - -// DNSNameServers is returned when retrieving the list of nameservers. -// It is also the structure provided when setting nameservers. -type DNSNameServers struct { - DNS []string `json:"dns"` // DNS name servers -} - -// DNSNameServersPostResponse is returned when setting the list of DNS nameservers. -// -// It includes the MagicDNS status since nameservers changes may affect MagicDNS. -type DNSNameServersPostResponse struct { - DNS []string `json:"dns"` // DNS name servers - MagicDNS bool `json:"magicDNS"` // whether MagicDNS is active for this tailnet (enabled + has fallback nameservers) -} - -// DNSSearchpaths is the list of search paths for a given domain. -type DNSSearchPaths struct { - SearchPaths []string `json:"searchPaths"` // DNS search paths -} - -// DNSPreferences is the preferences set for a given tailnet. -// -// It includes MagicDNS which can be turned on or off. To enable MagicDNS, -// there must be at least one nameserver. When all nameservers are removed, -// MagicDNS is disabled. -type DNSPreferences struct { - MagicDNS bool `json:"magicDNS"` // whether MagicDNS is active for this tailnet (enabled + has fallback nameservers) -} - -func (c *Client) dnsGETRequest(ctx context.Context, endpoint string) ([]byte, error) { - path := fmt.Sprintf("%s/api/v2/tailnet/%s/dns/%s", c.baseURL(), c.tailnet, endpoint) - req, err := http.NewRequestWithContext(ctx, "GET", path, nil) - if err != nil { - return nil, err - } - b, resp, err := c.sendRequest(req) - if err != nil { - return nil, err - } - - // If status code was not successful, return the error. - // TODO: Change the check for the StatusCode to include other 2XX success codes. - if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) - } - - return b, nil -} - -func (c *Client) dnsPOSTRequest(ctx context.Context, endpoint string, postData any) ([]byte, error) { - path := fmt.Sprintf("%s/api/v2/tailnet/%s/dns/%s", c.baseURL(), c.tailnet, endpoint) - data, err := json.Marshal(&postData) - if err != nil { - return nil, err - } - - req, err := http.NewRequestWithContext(ctx, "POST", path, bytes.NewBuffer(data)) - req.Header.Set("Content-Type", "application/json") - if err != nil { - return nil, err - } - - b, resp, err := c.sendRequest(req) - if err != nil { - return nil, err - } - - // If status code was not successful, return the error. - // TODO: Change the check for the StatusCode to include other 2XX success codes. - if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) - } - - return b, nil -} - -// DNSConfig retrieves the DNSConfig settings for a domain. -func (c *Client) DNSConfig(ctx context.Context) (cfg *apitype.DNSConfig, err error) { - // Format return errors to be descriptive. - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.DNSConfig: %w", err) - } - }() - b, err := c.dnsGETRequest(ctx, "config") - if err != nil { - return nil, err - } - var dnsResp apitype.DNSConfig - err = json.Unmarshal(b, &dnsResp) - return &dnsResp, err -} - -func (c *Client) SetDNSConfig(ctx context.Context, cfg apitype.DNSConfig) (resp *apitype.DNSConfig, err error) { - // Format return errors to be descriptive. - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.SetDNSConfig: %w", err) - } - }() - var dnsResp apitype.DNSConfig - b, err := c.dnsPOSTRequest(ctx, "config", cfg) - if err != nil { - return nil, err - } - err = json.Unmarshal(b, &dnsResp) - return &dnsResp, err -} - -// NameServers retrieves the list of nameservers set for a domain. -func (c *Client) NameServers(ctx context.Context) (nameservers []string, err error) { - // Format return errors to be descriptive. - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.NameServers: %w", err) - } - }() - b, err := c.dnsGETRequest(ctx, "nameservers") - if err != nil { - return nil, err - } - var dnsResp DNSNameServers - err = json.Unmarshal(b, &dnsResp) - return dnsResp.DNS, err -} - -// SetNameServers sets the list of nameservers for a tailnet to the list provided -// by the user. -// -// It returns the new list of nameservers and the MagicDNS status in case it was -// affected by the change. For example, removing all nameservers will turn off -// MagicDNS. -func (c *Client) SetNameServers(ctx context.Context, nameservers []string) (dnsResp *DNSNameServersPostResponse, err error) { - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.SetNameServers: %w", err) - } - }() - dnsReq := DNSNameServers{DNS: nameservers} - b, err := c.dnsPOSTRequest(ctx, "nameservers", dnsReq) - if err != nil { - return nil, err - } - err = json.Unmarshal(b, &dnsResp) - return dnsResp, err -} - -// DNSPreferences retrieves the DNS preferences set for a tailnet. -// -// It returns the status of MagicDNS. -func (c *Client) DNSPreferences(ctx context.Context) (dnsResp *DNSPreferences, err error) { - // Format return errors to be descriptive. - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.DNSPreferences: %w", err) - } - }() - b, err := c.dnsGETRequest(ctx, "preferences") - if err != nil { - return nil, err - } - err = json.Unmarshal(b, &dnsResp) - return dnsResp, err -} - -// SetDNSPreferences sets the DNS preferences for a tailnet. -// -// MagicDNS can only be enabled when there is at least one nameserver provided. -// When all nameservers are removed, MagicDNS is disabled and will stay disabled, -// unless explicitly enabled by a user again. -func (c *Client) SetDNSPreferences(ctx context.Context, magicDNS bool) (dnsResp *DNSPreferences, err error) { - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.SetDNSPreferences: %w", err) - } - }() - dnsReq := DNSPreferences{MagicDNS: magicDNS} - b, err := c.dnsPOSTRequest(ctx, "preferences", dnsReq) - if err != nil { - return - } - err = json.Unmarshal(b, &dnsResp) - return dnsResp, err -} - -// SearchPaths retrieves the list of searchpaths set for a tailnet. -func (c *Client) SearchPaths(ctx context.Context) (searchpaths []string, err error) { - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.SearchPaths: %w", err) - } - }() - b, err := c.dnsGETRequest(ctx, "searchpaths") - if err != nil { - return nil, err - } - var dnsResp *DNSSearchPaths - err = json.Unmarshal(b, &dnsResp) - return dnsResp.SearchPaths, err -} - -// SetSearchPaths sets the list of searchpaths for a tailnet. -func (c *Client) SetSearchPaths(ctx context.Context, searchpaths []string) (newSearchPaths []string, err error) { - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.SetSearchPaths: %w", err) - } - }() - dnsReq := DNSSearchPaths{SearchPaths: searchpaths} - b, err := c.dnsPOSTRequest(ctx, "searchpaths", dnsReq) - if err != nil { - return nil, err - } - var dnsResp DNSSearchPaths - err = json.Unmarshal(b, &dnsResp) - return dnsResp.SearchPaths, err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build go1.19 + +package tailscale + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + + "tailscale.com/client/tailscale/apitype" +) + +// DNSNameServers is returned when retrieving the list of nameservers. +// It is also the structure provided when setting nameservers. +type DNSNameServers struct { + DNS []string `json:"dns"` // DNS name servers +} + +// DNSNameServersPostResponse is returned when setting the list of DNS nameservers. +// +// It includes the MagicDNS status since nameservers changes may affect MagicDNS. +type DNSNameServersPostResponse struct { + DNS []string `json:"dns"` // DNS name servers + MagicDNS bool `json:"magicDNS"` // whether MagicDNS is active for this tailnet (enabled + has fallback nameservers) +} + +// DNSSearchpaths is the list of search paths for a given domain. +type DNSSearchPaths struct { + SearchPaths []string `json:"searchPaths"` // DNS search paths +} + +// DNSPreferences is the preferences set for a given tailnet. +// +// It includes MagicDNS which can be turned on or off. To enable MagicDNS, +// there must be at least one nameserver. When all nameservers are removed, +// MagicDNS is disabled. +type DNSPreferences struct { + MagicDNS bool `json:"magicDNS"` // whether MagicDNS is active for this tailnet (enabled + has fallback nameservers) +} + +func (c *Client) dnsGETRequest(ctx context.Context, endpoint string) ([]byte, error) { + path := fmt.Sprintf("%s/api/v2/tailnet/%s/dns/%s", c.baseURL(), c.tailnet, endpoint) + req, err := http.NewRequestWithContext(ctx, "GET", path, nil) + if err != nil { + return nil, err + } + b, resp, err := c.sendRequest(req) + if err != nil { + return nil, err + } + + // If status code was not successful, return the error. + // TODO: Change the check for the StatusCode to include other 2XX success codes. + if resp.StatusCode != http.StatusOK { + return nil, handleErrorResponse(b, resp) + } + + return b, nil +} + +func (c *Client) dnsPOSTRequest(ctx context.Context, endpoint string, postData any) ([]byte, error) { + path := fmt.Sprintf("%s/api/v2/tailnet/%s/dns/%s", c.baseURL(), c.tailnet, endpoint) + data, err := json.Marshal(&postData) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, "POST", path, bytes.NewBuffer(data)) + req.Header.Set("Content-Type", "application/json") + if err != nil { + return nil, err + } + + b, resp, err := c.sendRequest(req) + if err != nil { + return nil, err + } + + // If status code was not successful, return the error. + // TODO: Change the check for the StatusCode to include other 2XX success codes. + if resp.StatusCode != http.StatusOK { + return nil, handleErrorResponse(b, resp) + } + + return b, nil +} + +// DNSConfig retrieves the DNSConfig settings for a domain. +func (c *Client) DNSConfig(ctx context.Context) (cfg *apitype.DNSConfig, err error) { + // Format return errors to be descriptive. + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.DNSConfig: %w", err) + } + }() + b, err := c.dnsGETRequest(ctx, "config") + if err != nil { + return nil, err + } + var dnsResp apitype.DNSConfig + err = json.Unmarshal(b, &dnsResp) + return &dnsResp, err +} + +func (c *Client) SetDNSConfig(ctx context.Context, cfg apitype.DNSConfig) (resp *apitype.DNSConfig, err error) { + // Format return errors to be descriptive. + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.SetDNSConfig: %w", err) + } + }() + var dnsResp apitype.DNSConfig + b, err := c.dnsPOSTRequest(ctx, "config", cfg) + if err != nil { + return nil, err + } + err = json.Unmarshal(b, &dnsResp) + return &dnsResp, err +} + +// NameServers retrieves the list of nameservers set for a domain. +func (c *Client) NameServers(ctx context.Context) (nameservers []string, err error) { + // Format return errors to be descriptive. + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.NameServers: %w", err) + } + }() + b, err := c.dnsGETRequest(ctx, "nameservers") + if err != nil { + return nil, err + } + var dnsResp DNSNameServers + err = json.Unmarshal(b, &dnsResp) + return dnsResp.DNS, err +} + +// SetNameServers sets the list of nameservers for a tailnet to the list provided +// by the user. +// +// It returns the new list of nameservers and the MagicDNS status in case it was +// affected by the change. For example, removing all nameservers will turn off +// MagicDNS. +func (c *Client) SetNameServers(ctx context.Context, nameservers []string) (dnsResp *DNSNameServersPostResponse, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.SetNameServers: %w", err) + } + }() + dnsReq := DNSNameServers{DNS: nameservers} + b, err := c.dnsPOSTRequest(ctx, "nameservers", dnsReq) + if err != nil { + return nil, err + } + err = json.Unmarshal(b, &dnsResp) + return dnsResp, err +} + +// DNSPreferences retrieves the DNS preferences set for a tailnet. +// +// It returns the status of MagicDNS. +func (c *Client) DNSPreferences(ctx context.Context) (dnsResp *DNSPreferences, err error) { + // Format return errors to be descriptive. + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.DNSPreferences: %w", err) + } + }() + b, err := c.dnsGETRequest(ctx, "preferences") + if err != nil { + return nil, err + } + err = json.Unmarshal(b, &dnsResp) + return dnsResp, err +} + +// SetDNSPreferences sets the DNS preferences for a tailnet. +// +// MagicDNS can only be enabled when there is at least one nameserver provided. +// When all nameservers are removed, MagicDNS is disabled and will stay disabled, +// unless explicitly enabled by a user again. +func (c *Client) SetDNSPreferences(ctx context.Context, magicDNS bool) (dnsResp *DNSPreferences, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.SetDNSPreferences: %w", err) + } + }() + dnsReq := DNSPreferences{MagicDNS: magicDNS} + b, err := c.dnsPOSTRequest(ctx, "preferences", dnsReq) + if err != nil { + return + } + err = json.Unmarshal(b, &dnsResp) + return dnsResp, err +} + +// SearchPaths retrieves the list of searchpaths set for a tailnet. +func (c *Client) SearchPaths(ctx context.Context) (searchpaths []string, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.SearchPaths: %w", err) + } + }() + b, err := c.dnsGETRequest(ctx, "searchpaths") + if err != nil { + return nil, err + } + var dnsResp *DNSSearchPaths + err = json.Unmarshal(b, &dnsResp) + return dnsResp.SearchPaths, err +} + +// SetSearchPaths sets the list of searchpaths for a tailnet. +func (c *Client) SetSearchPaths(ctx context.Context, searchpaths []string) (newSearchPaths []string, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.SetSearchPaths: %w", err) + } + }() + dnsReq := DNSSearchPaths{SearchPaths: searchpaths} + b, err := c.dnsPOSTRequest(ctx, "searchpaths", dnsReq) + if err != nil { + return nil, err + } + var dnsResp DNSSearchPaths + err = json.Unmarshal(b, &dnsResp) + return dnsResp.SearchPaths, err +} diff --git a/client/tailscale/example/servetls/servetls.go b/client/tailscale/example/servetls/servetls.go index e426cbea2..f48e90d16 100644 --- a/client/tailscale/example/servetls/servetls.go +++ b/client/tailscale/example/servetls/servetls.go @@ -1,28 +1,28 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The servetls program shows how to run an HTTPS server -// using a Tailscale cert via LetsEncrypt. -package main - -import ( - "crypto/tls" - "io" - "log" - "net/http" - - "tailscale.com/client/tailscale" -) - -func main() { - s := &http.Server{ - TLSConfig: &tls.Config{ - GetCertificate: tailscale.GetCertificate, - }, - Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - io.WriteString(w, "

Hello from Tailscale!

It works.") - }), - } - log.Printf("Running TLS server on :443 ...") - log.Fatal(s.ListenAndServeTLS("", "")) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The servetls program shows how to run an HTTPS server +// using a Tailscale cert via LetsEncrypt. +package main + +import ( + "crypto/tls" + "io" + "log" + "net/http" + + "tailscale.com/client/tailscale" +) + +func main() { + s := &http.Server{ + TLSConfig: &tls.Config{ + GetCertificate: tailscale.GetCertificate, + }, + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "

Hello from Tailscale!

It works.") + }), + } + log.Printf("Running TLS server on :443 ...") + log.Fatal(s.ListenAndServeTLS("", "")) +} diff --git a/client/tailscale/keys.go b/client/tailscale/keys.go index ae5f721b7..84bcdfae6 100644 --- a/client/tailscale/keys.go +++ b/client/tailscale/keys.go @@ -1,166 +1,166 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tailscale - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "net/http" - "time" -) - -// Key represents a Tailscale API or auth key. -type Key struct { - ID string `json:"id"` - Created time.Time `json:"created"` - Expires time.Time `json:"expires"` - Capabilities KeyCapabilities `json:"capabilities"` -} - -// KeyCapabilities are the capabilities of a Key. -type KeyCapabilities struct { - Devices KeyDeviceCapabilities `json:"devices,omitempty"` -} - -// KeyDeviceCapabilities are the device-related capabilities of a Key. -type KeyDeviceCapabilities struct { - Create KeyDeviceCreateCapabilities `json:"create"` -} - -// KeyDeviceCreateCapabilities are the device creation capabilities of a Key. -type KeyDeviceCreateCapabilities struct { - Reusable bool `json:"reusable"` - Ephemeral bool `json:"ephemeral"` - Preauthorized bool `json:"preauthorized"` - Tags []string `json:"tags,omitempty"` -} - -// Keys returns the list of keys for the current user. -func (c *Client) Keys(ctx context.Context) ([]string, error) { - path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys", c.baseURL(), c.tailnet) - req, err := http.NewRequestWithContext(ctx, "GET", path, nil) - if err != nil { - return nil, err - } - - b, resp, err := c.sendRequest(req) - if err != nil { - return nil, err - } - if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) - } - - var keys struct { - Keys []*Key `json:"keys"` - } - if err := json.Unmarshal(b, &keys); err != nil { - return nil, err - } - ret := make([]string, 0, len(keys.Keys)) - for _, k := range keys.Keys { - ret = append(ret, k.ID) - } - return ret, nil -} - -// CreateKey creates a new key for the current user. Currently, only auth keys -// can be created. It returns the secret key itself, which cannot be retrieved again -// later, and the key metadata. -// -// To create a key with a specific expiry, use CreateKeyWithExpiry. -func (c *Client) CreateKey(ctx context.Context, caps KeyCapabilities) (keySecret string, keyMeta *Key, _ error) { - return c.CreateKeyWithExpiry(ctx, caps, 0) -} - -// CreateKeyWithExpiry is like CreateKey, but allows specifying a expiration time. -// -// The time is truncated to a whole number of seconds. If zero, that means no expiration. -func (c *Client) CreateKeyWithExpiry(ctx context.Context, caps KeyCapabilities, expiry time.Duration) (keySecret string, keyMeta *Key, _ error) { - - // convert expirySeconds to an int64 (seconds) - expirySeconds := int64(expiry.Seconds()) - if expirySeconds < 0 { - return "", nil, fmt.Errorf("expiry must be positive") - } - if expirySeconds == 0 && expiry != 0 { - return "", nil, fmt.Errorf("non-zero expiry must be at least one second") - } - - keyRequest := struct { - Capabilities KeyCapabilities `json:"capabilities"` - ExpirySeconds int64 `json:"expirySeconds,omitempty"` - }{caps, int64(expirySeconds)} - bs, err := json.Marshal(keyRequest) - if err != nil { - return "", nil, err - } - - path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys", c.baseURL(), c.tailnet) - req, err := http.NewRequestWithContext(ctx, "POST", path, bytes.NewReader(bs)) - if err != nil { - return "", nil, err - } - - b, resp, err := c.sendRequest(req) - if err != nil { - return "", nil, err - } - if resp.StatusCode != http.StatusOK { - return "", nil, handleErrorResponse(b, resp) - } - - var key struct { - Key - Secret string `json:"key"` - } - if err := json.Unmarshal(b, &key); err != nil { - return "", nil, err - } - return key.Secret, &key.Key, nil -} - -// Key returns the metadata for the given key ID. Currently, capabilities are -// only returned for auth keys, API keys only return general metadata. -func (c *Client) Key(ctx context.Context, id string) (*Key, error) { - path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys/%s", c.baseURL(), c.tailnet, id) - req, err := http.NewRequestWithContext(ctx, "GET", path, nil) - if err != nil { - return nil, err - } - - b, resp, err := c.sendRequest(req) - if err != nil { - return nil, err - } - if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) - } - - var key Key - if err := json.Unmarshal(b, &key); err != nil { - return nil, err - } - return &key, nil -} - -// DeleteKey deletes the key with the given ID. -func (c *Client) DeleteKey(ctx context.Context, id string) error { - path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys/%s", c.baseURL(), c.tailnet, id) - req, err := http.NewRequestWithContext(ctx, "DELETE", path, nil) - if err != nil { - return err - } - - b, resp, err := c.sendRequest(req) - if err != nil { - return err - } - if resp.StatusCode != http.StatusOK { - return handleErrorResponse(b, resp) - } - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tailscale + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "time" +) + +// Key represents a Tailscale API or auth key. +type Key struct { + ID string `json:"id"` + Created time.Time `json:"created"` + Expires time.Time `json:"expires"` + Capabilities KeyCapabilities `json:"capabilities"` +} + +// KeyCapabilities are the capabilities of a Key. +type KeyCapabilities struct { + Devices KeyDeviceCapabilities `json:"devices,omitempty"` +} + +// KeyDeviceCapabilities are the device-related capabilities of a Key. +type KeyDeviceCapabilities struct { + Create KeyDeviceCreateCapabilities `json:"create"` +} + +// KeyDeviceCreateCapabilities are the device creation capabilities of a Key. +type KeyDeviceCreateCapabilities struct { + Reusable bool `json:"reusable"` + Ephemeral bool `json:"ephemeral"` + Preauthorized bool `json:"preauthorized"` + Tags []string `json:"tags,omitempty"` +} + +// Keys returns the list of keys for the current user. +func (c *Client) Keys(ctx context.Context) ([]string, error) { + path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys", c.baseURL(), c.tailnet) + req, err := http.NewRequestWithContext(ctx, "GET", path, nil) + if err != nil { + return nil, err + } + + b, resp, err := c.sendRequest(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, handleErrorResponse(b, resp) + } + + var keys struct { + Keys []*Key `json:"keys"` + } + if err := json.Unmarshal(b, &keys); err != nil { + return nil, err + } + ret := make([]string, 0, len(keys.Keys)) + for _, k := range keys.Keys { + ret = append(ret, k.ID) + } + return ret, nil +} + +// CreateKey creates a new key for the current user. Currently, only auth keys +// can be created. It returns the secret key itself, which cannot be retrieved again +// later, and the key metadata. +// +// To create a key with a specific expiry, use CreateKeyWithExpiry. +func (c *Client) CreateKey(ctx context.Context, caps KeyCapabilities) (keySecret string, keyMeta *Key, _ error) { + return c.CreateKeyWithExpiry(ctx, caps, 0) +} + +// CreateKeyWithExpiry is like CreateKey, but allows specifying a expiration time. +// +// The time is truncated to a whole number of seconds. If zero, that means no expiration. +func (c *Client) CreateKeyWithExpiry(ctx context.Context, caps KeyCapabilities, expiry time.Duration) (keySecret string, keyMeta *Key, _ error) { + + // convert expirySeconds to an int64 (seconds) + expirySeconds := int64(expiry.Seconds()) + if expirySeconds < 0 { + return "", nil, fmt.Errorf("expiry must be positive") + } + if expirySeconds == 0 && expiry != 0 { + return "", nil, fmt.Errorf("non-zero expiry must be at least one second") + } + + keyRequest := struct { + Capabilities KeyCapabilities `json:"capabilities"` + ExpirySeconds int64 `json:"expirySeconds,omitempty"` + }{caps, int64(expirySeconds)} + bs, err := json.Marshal(keyRequest) + if err != nil { + return "", nil, err + } + + path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys", c.baseURL(), c.tailnet) + req, err := http.NewRequestWithContext(ctx, "POST", path, bytes.NewReader(bs)) + if err != nil { + return "", nil, err + } + + b, resp, err := c.sendRequest(req) + if err != nil { + return "", nil, err + } + if resp.StatusCode != http.StatusOK { + return "", nil, handleErrorResponse(b, resp) + } + + var key struct { + Key + Secret string `json:"key"` + } + if err := json.Unmarshal(b, &key); err != nil { + return "", nil, err + } + return key.Secret, &key.Key, nil +} + +// Key returns the metadata for the given key ID. Currently, capabilities are +// only returned for auth keys, API keys only return general metadata. +func (c *Client) Key(ctx context.Context, id string) (*Key, error) { + path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys/%s", c.baseURL(), c.tailnet, id) + req, err := http.NewRequestWithContext(ctx, "GET", path, nil) + if err != nil { + return nil, err + } + + b, resp, err := c.sendRequest(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, handleErrorResponse(b, resp) + } + + var key Key + if err := json.Unmarshal(b, &key); err != nil { + return nil, err + } + return &key, nil +} + +// DeleteKey deletes the key with the given ID. +func (c *Client) DeleteKey(ctx context.Context, id string) error { + path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys/%s", c.baseURL(), c.tailnet, id) + req, err := http.NewRequestWithContext(ctx, "DELETE", path, nil) + if err != nil { + return err + } + + b, resp, err := c.sendRequest(req) + if err != nil { + return err + } + if resp.StatusCode != http.StatusOK { + return handleErrorResponse(b, resp) + } + return nil +} diff --git a/client/tailscale/routes.go b/client/tailscale/routes.go index 41415d1b4..5912fc46c 100644 --- a/client/tailscale/routes.go +++ b/client/tailscale/routes.go @@ -1,95 +1,95 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build go1.19 - -package tailscale - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "net/http" - "net/netip" -) - -// Routes contains the lists of subnet routes that are currently advertised by a device, -// as well as the subnets that are enabled to be routed by the device. -type Routes struct { - AdvertisedRoutes []netip.Prefix `json:"advertisedRoutes"` - EnabledRoutes []netip.Prefix `json:"enabledRoutes"` -} - -// Routes retrieves the list of subnet routes that have been enabled for a device. -// The routes that are returned are not necessarily advertised by the device, -// they have only been preapproved. -func (c *Client) Routes(ctx context.Context, deviceID string) (routes *Routes, err error) { - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.Routes: %w", err) - } - }() - - path := fmt.Sprintf("%s/api/v2/device/%s/routes", c.baseURL(), deviceID) - req, err := http.NewRequestWithContext(ctx, "GET", path, nil) - if err != nil { - return nil, err - } - - b, resp, err := c.sendRequest(req) - if err != nil { - return nil, err - } - // If status code was not successful, return the error. - // TODO: Change the check for the StatusCode to include other 2XX success codes. - if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) - } - - var sr Routes - err = json.Unmarshal(b, &sr) - return &sr, err -} - -type postRoutesParams struct { - Routes []netip.Prefix `json:"routes"` -} - -// SetRoutes updates the list of subnets that are enabled for a device. -// Subnets must be parsable by net/netip.ParsePrefix. -// Subnets do not have to be currently advertised by a device, they may be pre-enabled. -// Returns the updated list of enabled and advertised subnet routes in a *Routes object. -func (c *Client) SetRoutes(ctx context.Context, deviceID string, subnets []netip.Prefix) (routes *Routes, err error) { - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.SetRoutes: %w", err) - } - }() - params := &postRoutesParams{Routes: subnets} - data, err := json.Marshal(params) - if err != nil { - return nil, err - } - path := fmt.Sprintf("%s/api/v2/device/%s/routes", c.baseURL(), deviceID) - req, err := http.NewRequestWithContext(ctx, "POST", path, bytes.NewBuffer(data)) - if err != nil { - return nil, err - } - - b, resp, err := c.sendRequest(req) - if err != nil { - return nil, err - } - // If status code was not successful, return the error. - // TODO: Change the check for the StatusCode to include other 2XX success codes. - if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) - } - - var srr *Routes - if err := json.Unmarshal(b, &srr); err != nil { - return nil, err - } - return srr, err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build go1.19 + +package tailscale + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/netip" +) + +// Routes contains the lists of subnet routes that are currently advertised by a device, +// as well as the subnets that are enabled to be routed by the device. +type Routes struct { + AdvertisedRoutes []netip.Prefix `json:"advertisedRoutes"` + EnabledRoutes []netip.Prefix `json:"enabledRoutes"` +} + +// Routes retrieves the list of subnet routes that have been enabled for a device. +// The routes that are returned are not necessarily advertised by the device, +// they have only been preapproved. +func (c *Client) Routes(ctx context.Context, deviceID string) (routes *Routes, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.Routes: %w", err) + } + }() + + path := fmt.Sprintf("%s/api/v2/device/%s/routes", c.baseURL(), deviceID) + req, err := http.NewRequestWithContext(ctx, "GET", path, nil) + if err != nil { + return nil, err + } + + b, resp, err := c.sendRequest(req) + if err != nil { + return nil, err + } + // If status code was not successful, return the error. + // TODO: Change the check for the StatusCode to include other 2XX success codes. + if resp.StatusCode != http.StatusOK { + return nil, handleErrorResponse(b, resp) + } + + var sr Routes + err = json.Unmarshal(b, &sr) + return &sr, err +} + +type postRoutesParams struct { + Routes []netip.Prefix `json:"routes"` +} + +// SetRoutes updates the list of subnets that are enabled for a device. +// Subnets must be parsable by net/netip.ParsePrefix. +// Subnets do not have to be currently advertised by a device, they may be pre-enabled. +// Returns the updated list of enabled and advertised subnet routes in a *Routes object. +func (c *Client) SetRoutes(ctx context.Context, deviceID string, subnets []netip.Prefix) (routes *Routes, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.SetRoutes: %w", err) + } + }() + params := &postRoutesParams{Routes: subnets} + data, err := json.Marshal(params) + if err != nil { + return nil, err + } + path := fmt.Sprintf("%s/api/v2/device/%s/routes", c.baseURL(), deviceID) + req, err := http.NewRequestWithContext(ctx, "POST", path, bytes.NewBuffer(data)) + if err != nil { + return nil, err + } + + b, resp, err := c.sendRequest(req) + if err != nil { + return nil, err + } + // If status code was not successful, return the error. + // TODO: Change the check for the StatusCode to include other 2XX success codes. + if resp.StatusCode != http.StatusOK { + return nil, handleErrorResponse(b, resp) + } + + var srr *Routes + if err := json.Unmarshal(b, &srr); err != nil { + return nil, err + } + return srr, err +} diff --git a/client/tailscale/tailnet.go b/client/tailscale/tailnet.go index eef2dca20..2539e7f23 100644 --- a/client/tailscale/tailnet.go +++ b/client/tailscale/tailnet.go @@ -1,42 +1,42 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build go1.19 - -package tailscale - -import ( - "context" - "fmt" - "net/http" - "net/url" - - "tailscale.com/util/httpm" -) - -// TailnetDeleteRequest handles sending a DELETE request for a tailnet to control. -func (c *Client) TailnetDeleteRequest(ctx context.Context, tailnetID string) (err error) { - defer func() { - if err != nil { - err = fmt.Errorf("tailscale.DeleteTailnet: %w", err) - } - }() - - path := fmt.Sprintf("%s/api/v2/tailnet/%s", c.baseURL(), url.PathEscape(string(tailnetID))) - req, err := http.NewRequestWithContext(ctx, httpm.DELETE, path, nil) - if err != nil { - return err - } - - c.setAuth(req) - b, resp, err := c.sendRequest(req) - if err != nil { - return err - } - - if resp.StatusCode != http.StatusOK { - return handleErrorResponse(b, resp) - } - - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build go1.19 + +package tailscale + +import ( + "context" + "fmt" + "net/http" + "net/url" + + "tailscale.com/util/httpm" +) + +// TailnetDeleteRequest handles sending a DELETE request for a tailnet to control. +func (c *Client) TailnetDeleteRequest(ctx context.Context, tailnetID string) (err error) { + defer func() { + if err != nil { + err = fmt.Errorf("tailscale.DeleteTailnet: %w", err) + } + }() + + path := fmt.Sprintf("%s/api/v2/tailnet/%s", c.baseURL(), url.PathEscape(string(tailnetID))) + req, err := http.NewRequestWithContext(ctx, httpm.DELETE, path, nil) + if err != nil { + return err + } + + c.setAuth(req) + b, resp, err := c.sendRequest(req) + if err != nil { + return err + } + + if resp.StatusCode != http.StatusOK { + return handleErrorResponse(b, resp) + } + + return nil +} diff --git a/client/web/qnap.go b/client/web/qnap.go index 8fa5ee174..9bde64bf5 100644 --- a/client/web/qnap.go +++ b/client/web/qnap.go @@ -1,127 +1,127 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// qnap.go contains handlers and logic, such as authentication, -// that is specific to running the web client on QNAP. - -package web - -import ( - "crypto/tls" - "encoding/xml" - "errors" - "fmt" - "io" - "log" - "net/http" - "net/url" -) - -// authorizeQNAP authenticates the logged-in QNAP user and verifies that they -// are authorized to use the web client. -// If the user is not authorized to use the client, an error is returned. -func authorizeQNAP(r *http.Request) (authorized bool, err error) { - _, resp, err := qnapAuthn(r) - if err != nil { - return false, err - } - if resp.IsAdmin == 0 { - return false, errors.New("user is not an admin") - } - - return true, nil -} - -type qnapAuthResponse struct { - AuthPassed int `xml:"authPassed"` - IsAdmin int `xml:"isAdmin"` - AuthSID string `xml:"authSid"` - ErrorValue int `xml:"errorValue"` -} - -func qnapAuthn(r *http.Request) (string, *qnapAuthResponse, error) { - user, err := r.Cookie("NAS_USER") - if err != nil { - return "", nil, err - } - token, err := r.Cookie("qtoken") - if err == nil { - return qnapAuthnQtoken(r, user.Value, token.Value) - } - sid, err := r.Cookie("NAS_SID") - if err == nil { - return qnapAuthnSid(r, user.Value, sid.Value) - } - return "", nil, fmt.Errorf("not authenticated by any mechanism") -} - -// qnapAuthnURL returns the auth URL to use by inferring where the UI is -// running based on the request URL. This is necessary because QNAP has so -// many options, see https://github.com/tailscale/tailscale/issues/7108 -// and https://github.com/tailscale/tailscale/issues/6903 -func qnapAuthnURL(requestUrl string, query url.Values) string { - in, err := url.Parse(requestUrl) - scheme := "" - host := "" - if err != nil || in.Scheme == "" { - log.Printf("Cannot parse QNAP login URL %v", err) - - // try localhost and hope for the best - scheme = "http" - host = "localhost" - } else { - scheme = in.Scheme - host = in.Host - } - - u := url.URL{ - Scheme: scheme, - Host: host, - Path: "/cgi-bin/authLogin.cgi", - RawQuery: query.Encode(), - } - - return u.String() -} - -func qnapAuthnQtoken(r *http.Request, user, token string) (string, *qnapAuthResponse, error) { - query := url.Values{ - "qtoken": []string{token}, - "user": []string{user}, - } - return qnapAuthnFinish(user, qnapAuthnURL(r.URL.String(), query)) -} - -func qnapAuthnSid(r *http.Request, user, sid string) (string, *qnapAuthResponse, error) { - query := url.Values{ - "sid": []string{sid}, - } - return qnapAuthnFinish(user, qnapAuthnURL(r.URL.String(), query)) -} - -func qnapAuthnFinish(user, url string) (string, *qnapAuthResponse, error) { - // QNAP Force HTTPS mode uses a self-signed certificate. Even importing - // the QNAP root CA isn't enough, the cert doesn't have a usable CN nor - // SAN. See https://github.com/tailscale/tailscale/issues/6903 - tr := &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - } - client := &http.Client{Transport: tr} - resp, err := client.Get(url) - if err != nil { - return "", nil, err - } - defer resp.Body.Close() - out, err := io.ReadAll(resp.Body) - if err != nil { - return "", nil, err - } - authResp := &qnapAuthResponse{} - if err := xml.Unmarshal(out, authResp); err != nil { - return "", nil, err - } - if authResp.AuthPassed == 0 { - return "", nil, fmt.Errorf("not authenticated") - } - return user, authResp, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// qnap.go contains handlers and logic, such as authentication, +// that is specific to running the web client on QNAP. + +package web + +import ( + "crypto/tls" + "encoding/xml" + "errors" + "fmt" + "io" + "log" + "net/http" + "net/url" +) + +// authorizeQNAP authenticates the logged-in QNAP user and verifies that they +// are authorized to use the web client. +// If the user is not authorized to use the client, an error is returned. +func authorizeQNAP(r *http.Request) (authorized bool, err error) { + _, resp, err := qnapAuthn(r) + if err != nil { + return false, err + } + if resp.IsAdmin == 0 { + return false, errors.New("user is not an admin") + } + + return true, nil +} + +type qnapAuthResponse struct { + AuthPassed int `xml:"authPassed"` + IsAdmin int `xml:"isAdmin"` + AuthSID string `xml:"authSid"` + ErrorValue int `xml:"errorValue"` +} + +func qnapAuthn(r *http.Request) (string, *qnapAuthResponse, error) { + user, err := r.Cookie("NAS_USER") + if err != nil { + return "", nil, err + } + token, err := r.Cookie("qtoken") + if err == nil { + return qnapAuthnQtoken(r, user.Value, token.Value) + } + sid, err := r.Cookie("NAS_SID") + if err == nil { + return qnapAuthnSid(r, user.Value, sid.Value) + } + return "", nil, fmt.Errorf("not authenticated by any mechanism") +} + +// qnapAuthnURL returns the auth URL to use by inferring where the UI is +// running based on the request URL. This is necessary because QNAP has so +// many options, see https://github.com/tailscale/tailscale/issues/7108 +// and https://github.com/tailscale/tailscale/issues/6903 +func qnapAuthnURL(requestUrl string, query url.Values) string { + in, err := url.Parse(requestUrl) + scheme := "" + host := "" + if err != nil || in.Scheme == "" { + log.Printf("Cannot parse QNAP login URL %v", err) + + // try localhost and hope for the best + scheme = "http" + host = "localhost" + } else { + scheme = in.Scheme + host = in.Host + } + + u := url.URL{ + Scheme: scheme, + Host: host, + Path: "/cgi-bin/authLogin.cgi", + RawQuery: query.Encode(), + } + + return u.String() +} + +func qnapAuthnQtoken(r *http.Request, user, token string) (string, *qnapAuthResponse, error) { + query := url.Values{ + "qtoken": []string{token}, + "user": []string{user}, + } + return qnapAuthnFinish(user, qnapAuthnURL(r.URL.String(), query)) +} + +func qnapAuthnSid(r *http.Request, user, sid string) (string, *qnapAuthResponse, error) { + query := url.Values{ + "sid": []string{sid}, + } + return qnapAuthnFinish(user, qnapAuthnURL(r.URL.String(), query)) +} + +func qnapAuthnFinish(user, url string) (string, *qnapAuthResponse, error) { + // QNAP Force HTTPS mode uses a self-signed certificate. Even importing + // the QNAP root CA isn't enough, the cert doesn't have a usable CN nor + // SAN. See https://github.com/tailscale/tailscale/issues/6903 + tr := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + client := &http.Client{Transport: tr} + resp, err := client.Get(url) + if err != nil { + return "", nil, err + } + defer resp.Body.Close() + out, err := io.ReadAll(resp.Body) + if err != nil { + return "", nil, err + } + authResp := &qnapAuthResponse{} + if err := xml.Unmarshal(out, authResp); err != nil { + return "", nil, err + } + if authResp.AuthPassed == 0 { + return "", nil, fmt.Errorf("not authenticated") + } + return user, authResp, nil +} diff --git a/client/web/src/assets/icons/arrow-right.svg b/client/web/src/assets/icons/arrow-right.svg index 0a32ef484..fbc4bb7ae 100644 --- a/client/web/src/assets/icons/arrow-right.svg +++ b/client/web/src/assets/icons/arrow-right.svg @@ -1,4 +1,4 @@ - - - - + + + + diff --git a/client/web/src/assets/icons/arrow-up-circle.svg b/client/web/src/assets/icons/arrow-up-circle.svg index e64c836be..e9d009eb6 100644 --- a/client/web/src/assets/icons/arrow-up-circle.svg +++ b/client/web/src/assets/icons/arrow-up-circle.svg @@ -1,5 +1,5 @@ - - - - - + + + + + diff --git a/client/web/src/assets/icons/check-circle.svg b/client/web/src/assets/icons/check-circle.svg index 6c5ee519e..4daeed514 100644 --- a/client/web/src/assets/icons/check-circle.svg +++ b/client/web/src/assets/icons/check-circle.svg @@ -1,4 +1,4 @@ - - - - + + + + diff --git a/client/web/src/assets/icons/check.svg b/client/web/src/assets/icons/check.svg index 70027536a..efa11685d 100644 --- a/client/web/src/assets/icons/check.svg +++ b/client/web/src/assets/icons/check.svg @@ -1,3 +1,3 @@ - - - + + + diff --git a/client/web/src/assets/icons/chevron-down.svg b/client/web/src/assets/icons/chevron-down.svg index 993744c2f..afc98f255 100644 --- a/client/web/src/assets/icons/chevron-down.svg +++ b/client/web/src/assets/icons/chevron-down.svg @@ -1,3 +1,3 @@ - - - + + + diff --git a/client/web/src/assets/icons/eye.svg b/client/web/src/assets/icons/eye.svg index e27767477..b0b21ed3f 100644 --- a/client/web/src/assets/icons/eye.svg +++ b/client/web/src/assets/icons/eye.svg @@ -1,11 +1,11 @@ - - - - - - - - - - - + + + + + + + + + + + diff --git a/client/web/src/assets/icons/search.svg b/client/web/src/assets/icons/search.svg index 08eb2d3dc..782cd90ee 100644 --- a/client/web/src/assets/icons/search.svg +++ b/client/web/src/assets/icons/search.svg @@ -1,4 +1,4 @@ - - - - + + + + diff --git a/client/web/src/assets/icons/tailscale-icon.svg b/client/web/src/assets/icons/tailscale-icon.svg index de3c975ce..d6052fe5e 100644 --- a/client/web/src/assets/icons/tailscale-icon.svg +++ b/client/web/src/assets/icons/tailscale-icon.svg @@ -1,18 +1,18 @@ - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + diff --git a/client/web/src/assets/icons/tailscale-logo.svg b/client/web/src/assets/icons/tailscale-logo.svg index 94a9cc4ee..6d5c7ce0c 100644 --- a/client/web/src/assets/icons/tailscale-logo.svg +++ b/client/web/src/assets/icons/tailscale-logo.svg @@ -1,20 +1,20 @@ - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + diff --git a/client/web/src/assets/icons/user.svg b/client/web/src/assets/icons/user.svg index 7fa3d2603..29d86f049 100644 --- a/client/web/src/assets/icons/user.svg +++ b/client/web/src/assets/icons/user.svg @@ -1,4 +1,4 @@ - - - - + + + + diff --git a/client/web/src/assets/icons/x-circle.svg b/client/web/src/assets/icons/x-circle.svg index d6259c917..49afc5a03 100644 --- a/client/web/src/assets/icons/x-circle.svg +++ b/client/web/src/assets/icons/x-circle.svg @@ -1,5 +1,5 @@ - - - - - + + + + + diff --git a/client/web/synology.go b/client/web/synology.go index 548026383..922489d78 100644 --- a/client/web/synology.go +++ b/client/web/synology.go @@ -1,59 +1,59 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// synology.go contains handlers and logic, such as authentication, -// that is specific to running the web client on Synology. - -package web - -import ( - "errors" - "fmt" - "net/http" - "os/exec" - "strings" - - "tailscale.com/util/groupmember" -) - -// authorizeSynology authenticates the logged-in Synology user and verifies -// that they are authorized to use the web client. -// If the user is authenticated, but not authorized to use the client, an error is returned. -func authorizeSynology(r *http.Request) (authorized bool, err error) { - if !hasSynoToken(r) { - return false, nil - } - - // authenticate the Synology user - cmd := exec.Command("/usr/syno/synoman/webman/modules/authenticate.cgi") - out, err := cmd.CombinedOutput() - if err != nil { - return false, fmt.Errorf("auth: %v: %s", err, out) - } - user := strings.TrimSpace(string(out)) - - // check if the user is in the administrators group - isAdmin, err := groupmember.IsMemberOfGroup("administrators", user) - if err != nil { - return false, err - } - if !isAdmin { - return false, errors.New("not a member of administrators group") - } - - return true, nil -} - -// hasSynoToken returns true if the request include a SynoToken used for synology auth. -func hasSynoToken(r *http.Request) bool { - if r.Header.Get("X-Syno-Token") != "" { - return true - } - if r.URL.Query().Get("SynoToken") != "" { - return true - } - if r.Method == "POST" && r.FormValue("SynoToken") != "" { - return true - } - return false -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// synology.go contains handlers and logic, such as authentication, +// that is specific to running the web client on Synology. + +package web + +import ( + "errors" + "fmt" + "net/http" + "os/exec" + "strings" + + "tailscale.com/util/groupmember" +) + +// authorizeSynology authenticates the logged-in Synology user and verifies +// that they are authorized to use the web client. +// If the user is authenticated, but not authorized to use the client, an error is returned. +func authorizeSynology(r *http.Request) (authorized bool, err error) { + if !hasSynoToken(r) { + return false, nil + } + + // authenticate the Synology user + cmd := exec.Command("/usr/syno/synoman/webman/modules/authenticate.cgi") + out, err := cmd.CombinedOutput() + if err != nil { + return false, fmt.Errorf("auth: %v: %s", err, out) + } + user := strings.TrimSpace(string(out)) + + // check if the user is in the administrators group + isAdmin, err := groupmember.IsMemberOfGroup("administrators", user) + if err != nil { + return false, err + } + if !isAdmin { + return false, errors.New("not a member of administrators group") + } + + return true, nil +} + +// hasSynoToken returns true if the request include a SynoToken used for synology auth. +func hasSynoToken(r *http.Request) bool { + if r.Header.Get("X-Syno-Token") != "" { + return true + } + if r.URL.Query().Get("SynoToken") != "" { + return true + } + if r.Method == "POST" && r.FormValue("SynoToken") != "" { + return true + } + return false +} diff --git a/clientupdate/distsign/distsign.go b/clientupdate/distsign/distsign.go index aae620153..eba4b9267 100644 --- a/clientupdate/distsign/distsign.go +++ b/clientupdate/distsign/distsign.go @@ -1,486 +1,486 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package distsign implements signature and validation of arbitrary -// distributable files. -// -// There are 3 parties in this exchange: -// - builder, which creates files, signs them with signing keys and publishes -// to server -// - server, which distributes public signing keys, files and signatures -// - client, which downloads files and signatures from server, and validates -// the signatures -// -// There are 2 types of keys: -// - signing keys, that sign individual distributable files on the builder -// - root keys, that sign signing keys and are kept offline -// -// root keys -(sign)-> signing keys -(sign)-> files -// -// All keys are asymmetric Ed25519 key pairs. -// -// The server serves static files under some known prefix. The kinds of files are: -// - distsign.pub - bundle of PEM-encoded public signing keys -// - distsign.pub.sig - signature of distsign.pub using one of the root keys -// - $file - any distributable file -// - $file.sig - signature of $file using any of the signing keys -// -// The root public keys are baked into the client software at compile time. -// These keys are long-lived and prove the validity of current signing keys -// from distsign.pub. To rotate root keys, a new client release must be -// published, they are not rotated dynamically. There are multiple root keys in -// different locations specifically to allow this rotation without using the -// discarded root key for any new signatures. -// -// The signing public keys are fetched by the client dynamically before every -// download and can be rotated more readily, assuming that most deployed -// clients trust the root keys used to issue fresh signing keys. -package distsign - -import ( - "context" - "crypto/ed25519" - "crypto/rand" - "encoding/binary" - "encoding/pem" - "errors" - "fmt" - "hash" - "io" - "log" - "net/http" - "net/url" - "os" - "time" - - "github.com/hdevalence/ed25519consensus" - "golang.org/x/crypto/blake2s" - "tailscale.com/net/tshttpproxy" - "tailscale.com/types/logger" - "tailscale.com/util/httpm" - "tailscale.com/util/must" -) - -const ( - pemTypeRootPrivate = "ROOT PRIVATE KEY" - pemTypeRootPublic = "ROOT PUBLIC KEY" - pemTypeSigningPrivate = "SIGNING PRIVATE KEY" - pemTypeSigningPublic = "SIGNING PUBLIC KEY" - - downloadSizeLimit = 1 << 29 // 512MB - signingKeysSizeLimit = 1 << 20 // 1MB - signatureSizeLimit = ed25519.SignatureSize -) - -// RootKey is a root key used to sign signing keys. -type RootKey struct { - k ed25519.PrivateKey -} - -// GenerateRootKey generates a new root key pair and encodes it as PEM. -func GenerateRootKey() (priv, pub []byte, err error) { - pub, priv, err = ed25519.GenerateKey(rand.Reader) - if err != nil { - return nil, nil, err - } - return pem.EncodeToMemory(&pem.Block{ - Type: pemTypeRootPrivate, - Bytes: []byte(priv), - }), pem.EncodeToMemory(&pem.Block{ - Type: pemTypeRootPublic, - Bytes: []byte(pub), - }), nil -} - -// ParseRootKey parses the PEM-encoded private root key. The key must be in the -// same format as returned by GenerateRootKey. -func ParseRootKey(privKey []byte) (*RootKey, error) { - k, err := parsePrivateKey(privKey, pemTypeRootPrivate) - if err != nil { - return nil, fmt.Errorf("failed to parse root key: %w", err) - } - return &RootKey{k: k}, nil -} - -// SignSigningKeys signs the bundle of public signing keys. The bundle must be -// a sequence of PEM blocks joined with newlines. -func (r *RootKey) SignSigningKeys(pubBundle []byte) ([]byte, error) { - if _, err := ParseSigningKeyBundle(pubBundle); err != nil { - return nil, err - } - return ed25519.Sign(r.k, pubBundle), nil -} - -// SigningKey is a signing key used to sign packages. -type SigningKey struct { - k ed25519.PrivateKey -} - -// GenerateSigningKey generates a new signing key pair and encodes it as PEM. -func GenerateSigningKey() (priv, pub []byte, err error) { - pub, priv, err = ed25519.GenerateKey(rand.Reader) - if err != nil { - return nil, nil, err - } - return pem.EncodeToMemory(&pem.Block{ - Type: pemTypeSigningPrivate, - Bytes: []byte(priv), - }), pem.EncodeToMemory(&pem.Block{ - Type: pemTypeSigningPublic, - Bytes: []byte(pub), - }), nil -} - -// ParseSigningKey parses the PEM-encoded private signing key. The key must be -// in the same format as returned by GenerateSigningKey. -func ParseSigningKey(privKey []byte) (*SigningKey, error) { - k, err := parsePrivateKey(privKey, pemTypeSigningPrivate) - if err != nil { - return nil, fmt.Errorf("failed to parse root key: %w", err) - } - return &SigningKey{k: k}, nil -} - -// SignPackageHash signs the hash and the length of a package. Use PackageHash -// to compute the inputs. -func (s *SigningKey) SignPackageHash(hash []byte, len int64) ([]byte, error) { - if len <= 0 { - return nil, fmt.Errorf("package length must be positive, got %d", len) - } - msg := binary.LittleEndian.AppendUint64(hash, uint64(len)) - return ed25519.Sign(s.k, msg), nil -} - -// PackageHash is a hash.Hash that counts the number of bytes written. Use it -// to get the hash and length inputs to SigningKey.SignPackageHash. -type PackageHash struct { - hash.Hash - len int64 -} - -// NewPackageHash returns an initialized PackageHash using BLAKE2s. -func NewPackageHash() *PackageHash { - h, err := blake2s.New256(nil) - if err != nil { - // Should never happen with a nil key passed to blake2s. - panic(err) - } - return &PackageHash{Hash: h} -} - -func (ph *PackageHash) Write(b []byte) (int, error) { - ph.len += int64(len(b)) - return ph.Hash.Write(b) -} - -// Reset the PackageHash to its initial state. -func (ph *PackageHash) Reset() { - ph.len = 0 - ph.Hash.Reset() -} - -// Len returns the total number of bytes written. -func (ph *PackageHash) Len() int64 { return ph.len } - -// Client downloads and validates files from a distribution server. -type Client struct { - logf logger.Logf - roots []ed25519.PublicKey - pkgsAddr *url.URL -} - -// NewClient returns a new client for distribution server located at pkgsAddr, -// and uses embedded root keys from the roots/ subdirectory of this package. -func NewClient(logf logger.Logf, pkgsAddr string) (*Client, error) { - if logf == nil { - logf = log.Printf - } - u, err := url.Parse(pkgsAddr) - if err != nil { - return nil, fmt.Errorf("invalid pkgsAddr %q: %w", pkgsAddr, err) - } - return &Client{logf: logf, roots: roots(), pkgsAddr: u}, nil -} - -func (c *Client) url(path string) string { - return c.pkgsAddr.JoinPath(path).String() -} - -// Download fetches a file at path srcPath from pkgsAddr passed in NewClient. -// The file is downloaded to dstPath and its signature is validated using the -// embedded root keys. Download returns an error if anything goes wrong with -// the actual file download or with signature validation. -func (c *Client) Download(ctx context.Context, srcPath, dstPath string) error { - // Always fetch a fresh signing key. - sigPub, err := c.signingKeys() - if err != nil { - return err - } - - srcURL := c.url(srcPath) - sigURL := srcURL + ".sig" - - c.logf("Downloading %q", srcURL) - dstPathUnverified := dstPath + ".unverified" - hash, len, err := c.download(ctx, srcURL, dstPathUnverified, downloadSizeLimit) - if err != nil { - return err - } - c.logf("Downloading %q", sigURL) - sig, err := fetch(sigURL, signatureSizeLimit) - if err != nil { - // Best-effort clean up of downloaded package. - os.Remove(dstPathUnverified) - return err - } - msg := binary.LittleEndian.AppendUint64(hash, uint64(len)) - if !VerifyAny(sigPub, msg, sig) { - // Best-effort clean up of downloaded package. - os.Remove(dstPathUnverified) - return fmt.Errorf("signature %q for file %q does not validate with the current release signing key; either you are under attack, or attempting to download an old version of Tailscale which was signed with an older signing key", sigURL, srcURL) - } - c.logf("Signature OK") - - if err := os.Rename(dstPathUnverified, dstPath); err != nil { - return fmt.Errorf("failed to move %q to %q after signature validation", dstPathUnverified, dstPath) - } - - return nil -} - -// ValidateLocalBinary fetches the latest signature associated with the binary -// at srcURLPath and uses it to validate the file located on disk via -// localFilePath. ValidateLocalBinary returns an error if anything goes wrong -// with the signature download or with signature validation. -func (c *Client) ValidateLocalBinary(srcURLPath, localFilePath string) error { - // Always fetch a fresh signing key. - sigPub, err := c.signingKeys() - if err != nil { - return err - } - - srcURL := c.url(srcURLPath) - sigURL := srcURL + ".sig" - - localFile, err := os.Open(localFilePath) - if err != nil { - return err - } - defer localFile.Close() - - h := NewPackageHash() - _, err = io.Copy(h, localFile) - if err != nil { - return err - } - hash, hashLen := h.Sum(nil), h.Len() - - c.logf("Downloading %q", sigURL) - sig, err := fetch(sigURL, signatureSizeLimit) - if err != nil { - return err - } - - msg := binary.LittleEndian.AppendUint64(hash, uint64(hashLen)) - if !VerifyAny(sigPub, msg, sig) { - return fmt.Errorf("signature %q for file %q does not validate with the current release signing key; either you are under attack, or attempting to download an old version of Tailscale which was signed with an older signing key", sigURL, localFilePath) - } - c.logf("Signature OK") - - return nil -} - -// signingKeys fetches current signing keys from the server and validates them -// against the roots. Should be called before validation of any downloaded file -// to get the fresh keys. -func (c *Client) signingKeys() ([]ed25519.PublicKey, error) { - keyURL := c.url("distsign.pub") - sigURL := keyURL + ".sig" - raw, err := fetch(keyURL, signingKeysSizeLimit) - if err != nil { - return nil, err - } - sig, err := fetch(sigURL, signatureSizeLimit) - if err != nil { - return nil, err - } - if !VerifyAny(c.roots, raw, sig) { - return nil, fmt.Errorf("signature %q for key %q does not validate with any known root key; either you are under attack, or running a very old version of Tailscale with outdated root keys", sigURL, keyURL) - } - - keys, err := ParseSigningKeyBundle(raw) - if err != nil { - return nil, fmt.Errorf("cannot parse signing key bundle from %q: %w", keyURL, err) - } - return keys, nil -} - -// fetch reads the response body from url into memory, up to limit bytes. -func fetch(url string, limit int64) ([]byte, error) { - resp, err := http.Get(url) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - return io.ReadAll(io.LimitReader(resp.Body, limit)) -} - -// download writes the response body of url into a local file at dst, up to -// limit bytes. On success, the returned value is a BLAKE2s hash of the file. -func (c *Client) download(ctx context.Context, url, dst string, limit int64) ([]byte, int64, error) { - tr := http.DefaultTransport.(*http.Transport).Clone() - tr.Proxy = tshttpproxy.ProxyFromEnvironment - defer tr.CloseIdleConnections() - hc := &http.Client{Transport: tr} - - quickCtx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() - headReq := must.Get(http.NewRequestWithContext(quickCtx, httpm.HEAD, url, nil)) - - res, err := hc.Do(headReq) - if err != nil { - return nil, 0, err - } - if res.StatusCode != http.StatusOK { - return nil, 0, fmt.Errorf("HEAD %q: %v", url, res.Status) - } - if res.ContentLength <= 0 { - return nil, 0, fmt.Errorf("HEAD %q: unexpected Content-Length %v", url, res.ContentLength) - } - c.logf("Download size: %v", res.ContentLength) - - dlReq := must.Get(http.NewRequestWithContext(ctx, httpm.GET, url, nil)) - dlRes, err := hc.Do(dlReq) - if err != nil { - return nil, 0, err - } - defer dlRes.Body.Close() - // TODO(bradfitz): resume from existing partial file on disk - if dlRes.StatusCode != http.StatusOK { - return nil, 0, fmt.Errorf("GET %q: %v", url, dlRes.Status) - } - - of, err := os.Create(dst) - if err != nil { - return nil, 0, err - } - defer of.Close() - pw := &progressWriter{total: res.ContentLength, logf: c.logf} - h := NewPackageHash() - n, err := io.Copy(io.MultiWriter(of, h, pw), io.LimitReader(dlRes.Body, limit)) - if err != nil { - return nil, n, err - } - if n != res.ContentLength { - return nil, n, fmt.Errorf("GET %q: downloaded %v, want %v", url, n, res.ContentLength) - } - if err := dlRes.Body.Close(); err != nil { - return nil, n, err - } - if err := of.Close(); err != nil { - return nil, n, err - } - pw.print() - - return h.Sum(nil), h.Len(), nil -} - -type progressWriter struct { - done int64 - total int64 - lastPrint time.Time - logf logger.Logf -} - -func (pw *progressWriter) Write(p []byte) (n int, err error) { - pw.done += int64(len(p)) - if time.Since(pw.lastPrint) > 2*time.Second { - pw.print() - } - return len(p), nil -} - -func (pw *progressWriter) print() { - pw.lastPrint = time.Now() - pw.logf("Downloaded %v/%v (%.1f%%)", pw.done, pw.total, float64(pw.done)/float64(pw.total)*100) -} - -func parsePrivateKey(data []byte, typeTag string) (ed25519.PrivateKey, error) { - b, rest := pem.Decode(data) - if b == nil { - return nil, errors.New("failed to decode PEM data") - } - if len(rest) > 0 { - return nil, errors.New("trailing PEM data") - } - if b.Type != typeTag { - return nil, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag) - } - if len(b.Bytes) != ed25519.PrivateKeySize { - return nil, errors.New("private key has incorrect length for an Ed25519 private key") - } - return ed25519.PrivateKey(b.Bytes), nil -} - -// ParseSigningKeyBundle parses the bundle of PEM-encoded public signing keys. -func ParseSigningKeyBundle(bundle []byte) ([]ed25519.PublicKey, error) { - return parsePublicKeyBundle(bundle, pemTypeSigningPublic) -} - -// ParseRootKeyBundle parses the bundle of PEM-encoded public root keys. -func ParseRootKeyBundle(bundle []byte) ([]ed25519.PublicKey, error) { - return parsePublicKeyBundle(bundle, pemTypeRootPublic) -} - -func parsePublicKeyBundle(bundle []byte, typeTag string) ([]ed25519.PublicKey, error) { - var keys []ed25519.PublicKey - for len(bundle) > 0 { - pub, rest, err := parsePublicKey(bundle, typeTag) - if err != nil { - return nil, err - } - keys = append(keys, pub) - bundle = rest - } - if len(keys) == 0 { - return nil, errors.New("no signing keys found in the bundle") - } - return keys, nil -} - -func parseSinglePublicKey(data []byte, typeTag string) (ed25519.PublicKey, error) { - pub, rest, err := parsePublicKey(data, typeTag) - if err != nil { - return nil, err - } - if len(rest) > 0 { - return nil, errors.New("trailing PEM data") - } - return pub, err -} - -func parsePublicKey(data []byte, typeTag string) (pub ed25519.PublicKey, rest []byte, retErr error) { - b, rest := pem.Decode(data) - if b == nil { - return nil, nil, errors.New("failed to decode PEM data") - } - if b.Type != typeTag { - return nil, nil, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag) - } - if len(b.Bytes) != ed25519.PublicKeySize { - return nil, nil, errors.New("public key has incorrect length for an Ed25519 public key") - } - return ed25519.PublicKey(b.Bytes), rest, nil -} - -// VerifyAny verifies whether sig is valid for msg using any of the keys. -// VerifyAny will panic if any of the keys have the wrong size for Ed25519. -func VerifyAny(keys []ed25519.PublicKey, msg, sig []byte) bool { - for _, k := range keys { - if ed25519consensus.Verify(k, msg, sig) { - return true - } - } - return false -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package distsign implements signature and validation of arbitrary +// distributable files. +// +// There are 3 parties in this exchange: +// - builder, which creates files, signs them with signing keys and publishes +// to server +// - server, which distributes public signing keys, files and signatures +// - client, which downloads files and signatures from server, and validates +// the signatures +// +// There are 2 types of keys: +// - signing keys, that sign individual distributable files on the builder +// - root keys, that sign signing keys and are kept offline +// +// root keys -(sign)-> signing keys -(sign)-> files +// +// All keys are asymmetric Ed25519 key pairs. +// +// The server serves static files under some known prefix. The kinds of files are: +// - distsign.pub - bundle of PEM-encoded public signing keys +// - distsign.pub.sig - signature of distsign.pub using one of the root keys +// - $file - any distributable file +// - $file.sig - signature of $file using any of the signing keys +// +// The root public keys are baked into the client software at compile time. +// These keys are long-lived and prove the validity of current signing keys +// from distsign.pub. To rotate root keys, a new client release must be +// published, they are not rotated dynamically. There are multiple root keys in +// different locations specifically to allow this rotation without using the +// discarded root key for any new signatures. +// +// The signing public keys are fetched by the client dynamically before every +// download and can be rotated more readily, assuming that most deployed +// clients trust the root keys used to issue fresh signing keys. +package distsign + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "encoding/binary" + "encoding/pem" + "errors" + "fmt" + "hash" + "io" + "log" + "net/http" + "net/url" + "os" + "time" + + "github.com/hdevalence/ed25519consensus" + "golang.org/x/crypto/blake2s" + "tailscale.com/net/tshttpproxy" + "tailscale.com/types/logger" + "tailscale.com/util/httpm" + "tailscale.com/util/must" +) + +const ( + pemTypeRootPrivate = "ROOT PRIVATE KEY" + pemTypeRootPublic = "ROOT PUBLIC KEY" + pemTypeSigningPrivate = "SIGNING PRIVATE KEY" + pemTypeSigningPublic = "SIGNING PUBLIC KEY" + + downloadSizeLimit = 1 << 29 // 512MB + signingKeysSizeLimit = 1 << 20 // 1MB + signatureSizeLimit = ed25519.SignatureSize +) + +// RootKey is a root key used to sign signing keys. +type RootKey struct { + k ed25519.PrivateKey +} + +// GenerateRootKey generates a new root key pair and encodes it as PEM. +func GenerateRootKey() (priv, pub []byte, err error) { + pub, priv, err = ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, err + } + return pem.EncodeToMemory(&pem.Block{ + Type: pemTypeRootPrivate, + Bytes: []byte(priv), + }), pem.EncodeToMemory(&pem.Block{ + Type: pemTypeRootPublic, + Bytes: []byte(pub), + }), nil +} + +// ParseRootKey parses the PEM-encoded private root key. The key must be in the +// same format as returned by GenerateRootKey. +func ParseRootKey(privKey []byte) (*RootKey, error) { + k, err := parsePrivateKey(privKey, pemTypeRootPrivate) + if err != nil { + return nil, fmt.Errorf("failed to parse root key: %w", err) + } + return &RootKey{k: k}, nil +} + +// SignSigningKeys signs the bundle of public signing keys. The bundle must be +// a sequence of PEM blocks joined with newlines. +func (r *RootKey) SignSigningKeys(pubBundle []byte) ([]byte, error) { + if _, err := ParseSigningKeyBundle(pubBundle); err != nil { + return nil, err + } + return ed25519.Sign(r.k, pubBundle), nil +} + +// SigningKey is a signing key used to sign packages. +type SigningKey struct { + k ed25519.PrivateKey +} + +// GenerateSigningKey generates a new signing key pair and encodes it as PEM. +func GenerateSigningKey() (priv, pub []byte, err error) { + pub, priv, err = ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, err + } + return pem.EncodeToMemory(&pem.Block{ + Type: pemTypeSigningPrivate, + Bytes: []byte(priv), + }), pem.EncodeToMemory(&pem.Block{ + Type: pemTypeSigningPublic, + Bytes: []byte(pub), + }), nil +} + +// ParseSigningKey parses the PEM-encoded private signing key. The key must be +// in the same format as returned by GenerateSigningKey. +func ParseSigningKey(privKey []byte) (*SigningKey, error) { + k, err := parsePrivateKey(privKey, pemTypeSigningPrivate) + if err != nil { + return nil, fmt.Errorf("failed to parse root key: %w", err) + } + return &SigningKey{k: k}, nil +} + +// SignPackageHash signs the hash and the length of a package. Use PackageHash +// to compute the inputs. +func (s *SigningKey) SignPackageHash(hash []byte, len int64) ([]byte, error) { + if len <= 0 { + return nil, fmt.Errorf("package length must be positive, got %d", len) + } + msg := binary.LittleEndian.AppendUint64(hash, uint64(len)) + return ed25519.Sign(s.k, msg), nil +} + +// PackageHash is a hash.Hash that counts the number of bytes written. Use it +// to get the hash and length inputs to SigningKey.SignPackageHash. +type PackageHash struct { + hash.Hash + len int64 +} + +// NewPackageHash returns an initialized PackageHash using BLAKE2s. +func NewPackageHash() *PackageHash { + h, err := blake2s.New256(nil) + if err != nil { + // Should never happen with a nil key passed to blake2s. + panic(err) + } + return &PackageHash{Hash: h} +} + +func (ph *PackageHash) Write(b []byte) (int, error) { + ph.len += int64(len(b)) + return ph.Hash.Write(b) +} + +// Reset the PackageHash to its initial state. +func (ph *PackageHash) Reset() { + ph.len = 0 + ph.Hash.Reset() +} + +// Len returns the total number of bytes written. +func (ph *PackageHash) Len() int64 { return ph.len } + +// Client downloads and validates files from a distribution server. +type Client struct { + logf logger.Logf + roots []ed25519.PublicKey + pkgsAddr *url.URL +} + +// NewClient returns a new client for distribution server located at pkgsAddr, +// and uses embedded root keys from the roots/ subdirectory of this package. +func NewClient(logf logger.Logf, pkgsAddr string) (*Client, error) { + if logf == nil { + logf = log.Printf + } + u, err := url.Parse(pkgsAddr) + if err != nil { + return nil, fmt.Errorf("invalid pkgsAddr %q: %w", pkgsAddr, err) + } + return &Client{logf: logf, roots: roots(), pkgsAddr: u}, nil +} + +func (c *Client) url(path string) string { + return c.pkgsAddr.JoinPath(path).String() +} + +// Download fetches a file at path srcPath from pkgsAddr passed in NewClient. +// The file is downloaded to dstPath and its signature is validated using the +// embedded root keys. Download returns an error if anything goes wrong with +// the actual file download or with signature validation. +func (c *Client) Download(ctx context.Context, srcPath, dstPath string) error { + // Always fetch a fresh signing key. + sigPub, err := c.signingKeys() + if err != nil { + return err + } + + srcURL := c.url(srcPath) + sigURL := srcURL + ".sig" + + c.logf("Downloading %q", srcURL) + dstPathUnverified := dstPath + ".unverified" + hash, len, err := c.download(ctx, srcURL, dstPathUnverified, downloadSizeLimit) + if err != nil { + return err + } + c.logf("Downloading %q", sigURL) + sig, err := fetch(sigURL, signatureSizeLimit) + if err != nil { + // Best-effort clean up of downloaded package. + os.Remove(dstPathUnverified) + return err + } + msg := binary.LittleEndian.AppendUint64(hash, uint64(len)) + if !VerifyAny(sigPub, msg, sig) { + // Best-effort clean up of downloaded package. + os.Remove(dstPathUnverified) + return fmt.Errorf("signature %q for file %q does not validate with the current release signing key; either you are under attack, or attempting to download an old version of Tailscale which was signed with an older signing key", sigURL, srcURL) + } + c.logf("Signature OK") + + if err := os.Rename(dstPathUnverified, dstPath); err != nil { + return fmt.Errorf("failed to move %q to %q after signature validation", dstPathUnverified, dstPath) + } + + return nil +} + +// ValidateLocalBinary fetches the latest signature associated with the binary +// at srcURLPath and uses it to validate the file located on disk via +// localFilePath. ValidateLocalBinary returns an error if anything goes wrong +// with the signature download or with signature validation. +func (c *Client) ValidateLocalBinary(srcURLPath, localFilePath string) error { + // Always fetch a fresh signing key. + sigPub, err := c.signingKeys() + if err != nil { + return err + } + + srcURL := c.url(srcURLPath) + sigURL := srcURL + ".sig" + + localFile, err := os.Open(localFilePath) + if err != nil { + return err + } + defer localFile.Close() + + h := NewPackageHash() + _, err = io.Copy(h, localFile) + if err != nil { + return err + } + hash, hashLen := h.Sum(nil), h.Len() + + c.logf("Downloading %q", sigURL) + sig, err := fetch(sigURL, signatureSizeLimit) + if err != nil { + return err + } + + msg := binary.LittleEndian.AppendUint64(hash, uint64(hashLen)) + if !VerifyAny(sigPub, msg, sig) { + return fmt.Errorf("signature %q for file %q does not validate with the current release signing key; either you are under attack, or attempting to download an old version of Tailscale which was signed with an older signing key", sigURL, localFilePath) + } + c.logf("Signature OK") + + return nil +} + +// signingKeys fetches current signing keys from the server and validates them +// against the roots. Should be called before validation of any downloaded file +// to get the fresh keys. +func (c *Client) signingKeys() ([]ed25519.PublicKey, error) { + keyURL := c.url("distsign.pub") + sigURL := keyURL + ".sig" + raw, err := fetch(keyURL, signingKeysSizeLimit) + if err != nil { + return nil, err + } + sig, err := fetch(sigURL, signatureSizeLimit) + if err != nil { + return nil, err + } + if !VerifyAny(c.roots, raw, sig) { + return nil, fmt.Errorf("signature %q for key %q does not validate with any known root key; either you are under attack, or running a very old version of Tailscale with outdated root keys", sigURL, keyURL) + } + + keys, err := ParseSigningKeyBundle(raw) + if err != nil { + return nil, fmt.Errorf("cannot parse signing key bundle from %q: %w", keyURL, err) + } + return keys, nil +} + +// fetch reads the response body from url into memory, up to limit bytes. +func fetch(url string, limit int64) ([]byte, error) { + resp, err := http.Get(url) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + return io.ReadAll(io.LimitReader(resp.Body, limit)) +} + +// download writes the response body of url into a local file at dst, up to +// limit bytes. On success, the returned value is a BLAKE2s hash of the file. +func (c *Client) download(ctx context.Context, url, dst string, limit int64) ([]byte, int64, error) { + tr := http.DefaultTransport.(*http.Transport).Clone() + tr.Proxy = tshttpproxy.ProxyFromEnvironment + defer tr.CloseIdleConnections() + hc := &http.Client{Transport: tr} + + quickCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + headReq := must.Get(http.NewRequestWithContext(quickCtx, httpm.HEAD, url, nil)) + + res, err := hc.Do(headReq) + if err != nil { + return nil, 0, err + } + if res.StatusCode != http.StatusOK { + return nil, 0, fmt.Errorf("HEAD %q: %v", url, res.Status) + } + if res.ContentLength <= 0 { + return nil, 0, fmt.Errorf("HEAD %q: unexpected Content-Length %v", url, res.ContentLength) + } + c.logf("Download size: %v", res.ContentLength) + + dlReq := must.Get(http.NewRequestWithContext(ctx, httpm.GET, url, nil)) + dlRes, err := hc.Do(dlReq) + if err != nil { + return nil, 0, err + } + defer dlRes.Body.Close() + // TODO(bradfitz): resume from existing partial file on disk + if dlRes.StatusCode != http.StatusOK { + return nil, 0, fmt.Errorf("GET %q: %v", url, dlRes.Status) + } + + of, err := os.Create(dst) + if err != nil { + return nil, 0, err + } + defer of.Close() + pw := &progressWriter{total: res.ContentLength, logf: c.logf} + h := NewPackageHash() + n, err := io.Copy(io.MultiWriter(of, h, pw), io.LimitReader(dlRes.Body, limit)) + if err != nil { + return nil, n, err + } + if n != res.ContentLength { + return nil, n, fmt.Errorf("GET %q: downloaded %v, want %v", url, n, res.ContentLength) + } + if err := dlRes.Body.Close(); err != nil { + return nil, n, err + } + if err := of.Close(); err != nil { + return nil, n, err + } + pw.print() + + return h.Sum(nil), h.Len(), nil +} + +type progressWriter struct { + done int64 + total int64 + lastPrint time.Time + logf logger.Logf +} + +func (pw *progressWriter) Write(p []byte) (n int, err error) { + pw.done += int64(len(p)) + if time.Since(pw.lastPrint) > 2*time.Second { + pw.print() + } + return len(p), nil +} + +func (pw *progressWriter) print() { + pw.lastPrint = time.Now() + pw.logf("Downloaded %v/%v (%.1f%%)", pw.done, pw.total, float64(pw.done)/float64(pw.total)*100) +} + +func parsePrivateKey(data []byte, typeTag string) (ed25519.PrivateKey, error) { + b, rest := pem.Decode(data) + if b == nil { + return nil, errors.New("failed to decode PEM data") + } + if len(rest) > 0 { + return nil, errors.New("trailing PEM data") + } + if b.Type != typeTag { + return nil, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag) + } + if len(b.Bytes) != ed25519.PrivateKeySize { + return nil, errors.New("private key has incorrect length for an Ed25519 private key") + } + return ed25519.PrivateKey(b.Bytes), nil +} + +// ParseSigningKeyBundle parses the bundle of PEM-encoded public signing keys. +func ParseSigningKeyBundle(bundle []byte) ([]ed25519.PublicKey, error) { + return parsePublicKeyBundle(bundle, pemTypeSigningPublic) +} + +// ParseRootKeyBundle parses the bundle of PEM-encoded public root keys. +func ParseRootKeyBundle(bundle []byte) ([]ed25519.PublicKey, error) { + return parsePublicKeyBundle(bundle, pemTypeRootPublic) +} + +func parsePublicKeyBundle(bundle []byte, typeTag string) ([]ed25519.PublicKey, error) { + var keys []ed25519.PublicKey + for len(bundle) > 0 { + pub, rest, err := parsePublicKey(bundle, typeTag) + if err != nil { + return nil, err + } + keys = append(keys, pub) + bundle = rest + } + if len(keys) == 0 { + return nil, errors.New("no signing keys found in the bundle") + } + return keys, nil +} + +func parseSinglePublicKey(data []byte, typeTag string) (ed25519.PublicKey, error) { + pub, rest, err := parsePublicKey(data, typeTag) + if err != nil { + return nil, err + } + if len(rest) > 0 { + return nil, errors.New("trailing PEM data") + } + return pub, err +} + +func parsePublicKey(data []byte, typeTag string) (pub ed25519.PublicKey, rest []byte, retErr error) { + b, rest := pem.Decode(data) + if b == nil { + return nil, nil, errors.New("failed to decode PEM data") + } + if b.Type != typeTag { + return nil, nil, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag) + } + if len(b.Bytes) != ed25519.PublicKeySize { + return nil, nil, errors.New("public key has incorrect length for an Ed25519 public key") + } + return ed25519.PublicKey(b.Bytes), rest, nil +} + +// VerifyAny verifies whether sig is valid for msg using any of the keys. +// VerifyAny will panic if any of the keys have the wrong size for Ed25519. +func VerifyAny(keys []ed25519.PublicKey, msg, sig []byte) bool { + for _, k := range keys { + if ed25519consensus.Verify(k, msg, sig) { + return true + } + } + return false +} diff --git a/clientupdate/distsign/roots.go b/clientupdate/distsign/roots.go index df8655797..d5b47b7b6 100644 --- a/clientupdate/distsign/roots.go +++ b/clientupdate/distsign/roots.go @@ -1,54 +1,54 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package distsign - -import ( - "crypto/ed25519" - "embed" - "errors" - "fmt" - "path" - "path/filepath" - "sync" -) - -//go:embed roots -var rootsFS embed.FS - -var roots = sync.OnceValue(func() []ed25519.PublicKey { - roots, err := parseRoots() - if err != nil { - panic(err) - } - return roots -}) - -func parseRoots() ([]ed25519.PublicKey, error) { - files, err := rootsFS.ReadDir("roots") - if err != nil { - return nil, err - } - var keys []ed25519.PublicKey - for _, f := range files { - if !f.Type().IsRegular() { - continue - } - if filepath.Ext(f.Name()) != ".pem" { - continue - } - raw, err := rootsFS.ReadFile(path.Join("roots", f.Name())) - if err != nil { - return nil, err - } - key, err := parseSinglePublicKey(raw, pemTypeRootPublic) - if err != nil { - return nil, fmt.Errorf("parsing root key %q: %w", f.Name(), err) - } - keys = append(keys, key) - } - if len(keys) == 0 { - return nil, errors.New("no embedded root keys, please check clientupdate/distsign/roots/") - } - return keys, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package distsign + +import ( + "crypto/ed25519" + "embed" + "errors" + "fmt" + "path" + "path/filepath" + "sync" +) + +//go:embed roots +var rootsFS embed.FS + +var roots = sync.OnceValue(func() []ed25519.PublicKey { + roots, err := parseRoots() + if err != nil { + panic(err) + } + return roots +}) + +func parseRoots() ([]ed25519.PublicKey, error) { + files, err := rootsFS.ReadDir("roots") + if err != nil { + return nil, err + } + var keys []ed25519.PublicKey + for _, f := range files { + if !f.Type().IsRegular() { + continue + } + if filepath.Ext(f.Name()) != ".pem" { + continue + } + raw, err := rootsFS.ReadFile(path.Join("roots", f.Name())) + if err != nil { + return nil, err + } + key, err := parseSinglePublicKey(raw, pemTypeRootPublic) + if err != nil { + return nil, fmt.Errorf("parsing root key %q: %w", f.Name(), err) + } + keys = append(keys, key) + } + if len(keys) == 0 { + return nil, errors.New("no embedded root keys, please check clientupdate/distsign/roots/") + } + return keys, nil +} diff --git a/clientupdate/distsign/roots/crawshaw-root.pem b/clientupdate/distsign/roots/crawshaw-root.pem index 897a38295..f80b9aec7 100755 --- a/clientupdate/distsign/roots/crawshaw-root.pem +++ b/clientupdate/distsign/roots/crawshaw-root.pem @@ -1,3 +1,3 @@ ------BEGIN ROOT PUBLIC KEY----- -Psrabv2YNiEDhPlnLVSMtB5EKACm7zxvKxfvYD4i7X8= ------END ROOT PUBLIC KEY----- +-----BEGIN ROOT PUBLIC KEY----- +Psrabv2YNiEDhPlnLVSMtB5EKACm7zxvKxfvYD4i7X8= +-----END ROOT PUBLIC KEY----- diff --git a/clientupdate/distsign/roots/distsign-prod-root-1-pub.pem b/clientupdate/distsign/roots/distsign-prod-root-1-pub.pem index e2f937ed3..d5d6516ab 100644 --- a/clientupdate/distsign/roots/distsign-prod-root-1-pub.pem +++ b/clientupdate/distsign/roots/distsign-prod-root-1-pub.pem @@ -1,3 +1,3 @@ ------BEGIN ROOT PUBLIC KEY----- -ZjjKhUHBtLNRSO1dhOTjrXJGJ8lDe1594WM2XDuheVQ= ------END ROOT PUBLIC KEY----- +-----BEGIN ROOT PUBLIC KEY----- +ZjjKhUHBtLNRSO1dhOTjrXJGJ8lDe1594WM2XDuheVQ= +-----END ROOT PUBLIC KEY----- diff --git a/clientupdate/distsign/roots_test.go b/clientupdate/distsign/roots_test.go index ae0dfbc22..7a9452953 100644 --- a/clientupdate/distsign/roots_test.go +++ b/clientupdate/distsign/roots_test.go @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package distsign - -import "testing" - -func TestParseRoots(t *testing.T) { - roots, err := parseRoots() - if err != nil { - t.Fatal(err) - } - if len(roots) == 0 { - t.Error("parseRoots returned no root keys") - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package distsign + +import "testing" + +func TestParseRoots(t *testing.T) { + roots, err := parseRoots() + if err != nil { + t.Fatal(err) + } + if len(roots) == 0 { + t.Error("parseRoots returned no root keys") + } +} diff --git a/cmd/addlicense/main.go b/cmd/addlicense/main.go index 58ef7a471..a8fd9dd4a 100644 --- a/cmd/addlicense/main.go +++ b/cmd/addlicense/main.go @@ -1,73 +1,73 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Program addlicense adds a license header to a file. -// It is intended for use with 'go generate', -// so it has a slightly weird usage. -package main - -import ( - "flag" - "fmt" - "os" - "os/exec" -) - -var ( - file = flag.String("file", "", "file to modify") -) - -func usage() { - fmt.Fprintf(os.Stderr, ` -usage: addlicense -file FILE -`[1:]) - - flag.PrintDefaults() - fmt.Fprintf(os.Stderr, ` -addlicense adds a Tailscale license to the beginning of file. - -It is intended for use with 'go generate', so it also runs a subcommand, -which presumably creates the file. - -Sample usage: - -addlicense -file pull_strings.go stringer -type=pull -`[1:]) - os.Exit(2) -} - -func main() { - flag.Usage = usage - flag.Parse() - if len(flag.Args()) == 0 { - flag.Usage() - } - cmd := exec.Command(flag.Arg(0), flag.Args()[1:]...) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - err := cmd.Run() - check(err) - b, err := os.ReadFile(*file) - check(err) - f, err := os.OpenFile(*file, os.O_TRUNC|os.O_WRONLY, 0644) - check(err) - _, err = fmt.Fprint(f, license) - check(err) - _, err = f.Write(b) - check(err) - err = f.Close() - check(err) -} - -func check(err error) { - if err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) - } -} - -var license = ` -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -`[1:] +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Program addlicense adds a license header to a file. +// It is intended for use with 'go generate', +// so it has a slightly weird usage. +package main + +import ( + "flag" + "fmt" + "os" + "os/exec" +) + +var ( + file = flag.String("file", "", "file to modify") +) + +func usage() { + fmt.Fprintf(os.Stderr, ` +usage: addlicense -file FILE +`[1:]) + + flag.PrintDefaults() + fmt.Fprintf(os.Stderr, ` +addlicense adds a Tailscale license to the beginning of file. + +It is intended for use with 'go generate', so it also runs a subcommand, +which presumably creates the file. + +Sample usage: + +addlicense -file pull_strings.go stringer -type=pull +`[1:]) + os.Exit(2) +} + +func main() { + flag.Usage = usage + flag.Parse() + if len(flag.Args()) == 0 { + flag.Usage() + } + cmd := exec.Command(flag.Arg(0), flag.Args()[1:]...) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + err := cmd.Run() + check(err) + b, err := os.ReadFile(*file) + check(err) + f, err := os.OpenFile(*file, os.O_TRUNC|os.O_WRONLY, 0644) + check(err) + _, err = fmt.Fprint(f, license) + check(err) + _, err = f.Write(b) + check(err) + err = f.Close() + check(err) +} + +func check(err error) { + if err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} + +var license = ` +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +`[1:] diff --git a/cmd/cloner/cloner_test.go b/cmd/cloner/cloner_test.go index 83d33ab0e..d8d5df3cb 100644 --- a/cmd/cloner/cloner_test.go +++ b/cmd/cloner/cloner_test.go @@ -1,60 +1,60 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause -package main - -import ( - "reflect" - "testing" - - "tailscale.com/cmd/cloner/clonerex" -) - -func TestSliceContainer(t *testing.T) { - num := 5 - examples := []struct { - name string - in *clonerex.SliceContainer - }{ - { - name: "nil", - in: nil, - }, - { - name: "zero", - in: &clonerex.SliceContainer{}, - }, - { - name: "empty", - in: &clonerex.SliceContainer{ - Slice: []*int{}, - }, - }, - { - name: "nils", - in: &clonerex.SliceContainer{ - Slice: []*int{nil, nil, nil, nil, nil}, - }, - }, - { - name: "one", - in: &clonerex.SliceContainer{ - Slice: []*int{&num}, - }, - }, - { - name: "several", - in: &clonerex.SliceContainer{ - Slice: []*int{&num, &num, &num, &num, &num}, - }, - }, - } - - for _, ex := range examples { - t.Run(ex.name, func(t *testing.T) { - out := ex.in.Clone() - if !reflect.DeepEqual(ex.in, out) { - t.Errorf("Clone() = %v, want %v", out, ex.in) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause +package main + +import ( + "reflect" + "testing" + + "tailscale.com/cmd/cloner/clonerex" +) + +func TestSliceContainer(t *testing.T) { + num := 5 + examples := []struct { + name string + in *clonerex.SliceContainer + }{ + { + name: "nil", + in: nil, + }, + { + name: "zero", + in: &clonerex.SliceContainer{}, + }, + { + name: "empty", + in: &clonerex.SliceContainer{ + Slice: []*int{}, + }, + }, + { + name: "nils", + in: &clonerex.SliceContainer{ + Slice: []*int{nil, nil, nil, nil, nil}, + }, + }, + { + name: "one", + in: &clonerex.SliceContainer{ + Slice: []*int{&num}, + }, + }, + { + name: "several", + in: &clonerex.SliceContainer{ + Slice: []*int{&num, &num, &num, &num, &num}, + }, + }, + } + + for _, ex := range examples { + t.Run(ex.name, func(t *testing.T) { + out := ex.in.Clone() + if !reflect.DeepEqual(ex.in, out) { + t.Errorf("Clone() = %v, want %v", out, ex.in) + } + }) + } +} diff --git a/cmd/containerboot/test_tailscale.sh b/cmd/containerboot/test_tailscale.sh index dd56adf04..1fa10abb1 100644 --- a/cmd/containerboot/test_tailscale.sh +++ b/cmd/containerboot/test_tailscale.sh @@ -1,8 +1,8 @@ -#!/usr/bin/env bash -# -# This is a fake tailscale CLI (and also iptables and ip6tables) that -# records its arguments and exits successfully. -# -# It is used by main_test.go to test the behavior of containerboot. - -echo $0 $@ >>$TS_TEST_RECORD_ARGS +#!/usr/bin/env bash +# +# This is a fake tailscale CLI (and also iptables and ip6tables) that +# records its arguments and exits successfully. +# +# It is used by main_test.go to test the behavior of containerboot. + +echo $0 $@ >>$TS_TEST_RECORD_ARGS diff --git a/cmd/containerboot/test_tailscaled.sh b/cmd/containerboot/test_tailscaled.sh index b7404a0a9..335e2cb0d 100644 --- a/cmd/containerboot/test_tailscaled.sh +++ b/cmd/containerboot/test_tailscaled.sh @@ -1,38 +1,38 @@ -#!/usr/bin/env bash -# -# This is a fake tailscale daemon that records its arguments, symlinks a -# fake LocalAPI socket into place, and does nothing until terminated. -# -# It is used by main_test.go to test the behavior of containerboot. - -set -eu - -echo $0 $@ >>$TS_TEST_RECORD_ARGS - -socket="" -while [[ $# -gt 0 ]]; do - case $1 in - --socket=*) - socket="${1#--socket=}" - shift - ;; - --socket) - shift - socket="$1" - shift - ;; - *) - shift - ;; - esac -done - -if [[ -z "$socket" ]]; then - echo "didn't find socket path in args" - exit 1 -fi - -ln -s "$TS_TEST_SOCKET" "$socket" -trap 'rm -f "$socket"' EXIT - -while sleep 10; do :; done +#!/usr/bin/env bash +# +# This is a fake tailscale daemon that records its arguments, symlinks a +# fake LocalAPI socket into place, and does nothing until terminated. +# +# It is used by main_test.go to test the behavior of containerboot. + +set -eu + +echo $0 $@ >>$TS_TEST_RECORD_ARGS + +socket="" +while [[ $# -gt 0 ]]; do + case $1 in + --socket=*) + socket="${1#--socket=}" + shift + ;; + --socket) + shift + socket="$1" + shift + ;; + *) + shift + ;; + esac +done + +if [[ -z "$socket" ]]; then + echo "didn't find socket path in args" + exit 1 +fi + +ln -s "$TS_TEST_SOCKET" "$socket" +trap 'rm -f "$socket"' EXIT + +while sleep 10; do :; done diff --git a/cmd/get-authkey/.gitignore b/cmd/get-authkey/.gitignore index e00856fa1..3f9c9fb90 100644 --- a/cmd/get-authkey/.gitignore +++ b/cmd/get-authkey/.gitignore @@ -1 +1 @@ -get-authkey +get-authkey diff --git a/cmd/gitops-pusher/.gitignore b/cmd/gitops-pusher/.gitignore index eeed6e4bf..504452249 100644 --- a/cmd/gitops-pusher/.gitignore +++ b/cmd/gitops-pusher/.gitignore @@ -1 +1 @@ -version-cache.json +version-cache.json diff --git a/cmd/gitops-pusher/README.md b/cmd/gitops-pusher/README.md index b08125397..9f77ea970 100644 --- a/cmd/gitops-pusher/README.md +++ b/cmd/gitops-pusher/README.md @@ -1,48 +1,48 @@ -# gitops-pusher - -This is a small tool to help people achieve a -[GitOps](https://about.gitlab.com/topics/gitops/) workflow with Tailscale ACL -changes. This tool is intended to be used in a CI flow that looks like this: - -```yaml -name: Tailscale ACL syncing - -on: - push: - branches: [ "main" ] - pull_request: - branches: [ "main" ] - -jobs: - acls: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v4 - - - name: Setup Go environment - uses: actions/setup-go@v3.2.0 - - - name: Install gitops-pusher - run: go install tailscale.com/cmd/gitops-pusher@latest - - - name: Deploy ACL - if: github.event_name == 'push' - env: - TS_API_KEY: ${{ secrets.TS_API_KEY }} - TS_TAILNET: ${{ secrets.TS_TAILNET }} - run: | - ~/go/bin/gitops-pusher --policy-file ./policy.hujson apply - - - name: ACL tests - if: github.event_name == 'pull_request' - env: - TS_API_KEY: ${{ secrets.TS_API_KEY }} - TS_TAILNET: ${{ secrets.TS_TAILNET }} - run: | - ~/go/bin/gitops-pusher --policy-file ./policy.hujson test -``` - -Change the value of the `--policy-file` flag to point to the policy file on -disk. Policy files should be in [HuJSON](https://github.com/tailscale/hujson) -format. +# gitops-pusher + +This is a small tool to help people achieve a +[GitOps](https://about.gitlab.com/topics/gitops/) workflow with Tailscale ACL +changes. This tool is intended to be used in a CI flow that looks like this: + +```yaml +name: Tailscale ACL syncing + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + acls: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Setup Go environment + uses: actions/setup-go@v3.2.0 + + - name: Install gitops-pusher + run: go install tailscale.com/cmd/gitops-pusher@latest + + - name: Deploy ACL + if: github.event_name == 'push' + env: + TS_API_KEY: ${{ secrets.TS_API_KEY }} + TS_TAILNET: ${{ secrets.TS_TAILNET }} + run: | + ~/go/bin/gitops-pusher --policy-file ./policy.hujson apply + + - name: ACL tests + if: github.event_name == 'pull_request' + env: + TS_API_KEY: ${{ secrets.TS_API_KEY }} + TS_TAILNET: ${{ secrets.TS_TAILNET }} + run: | + ~/go/bin/gitops-pusher --policy-file ./policy.hujson test +``` + +Change the value of the `--policy-file` flag to point to the policy file on +disk. Policy files should be in [HuJSON](https://github.com/tailscale/hujson) +format. diff --git a/cmd/gitops-pusher/cache.go b/cmd/gitops-pusher/cache.go index 89225e6f8..6792e5e63 100644 --- a/cmd/gitops-pusher/cache.go +++ b/cmd/gitops-pusher/cache.go @@ -1,66 +1,66 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -import ( - "encoding/json" - "os" -) - -// Cache contains cached information about the last time this tool was run. -// -// This is serialized to a JSON file that should NOT be checked into git. -// It should be managed with either CI cache tools or stored locally somehow. The -// exact mechanism is irrelevant as long as it is consistent. -// -// This allows gitops-pusher to detect external ACL changes. I'm not sure what to -// call this problem, so I've been calling it the "three version problem" in my -// notes. The basic problem is that at any given time we only have two versions -// of the ACL file at any given point. In order to check if there has been -// tampering of the ACL files in the admin panel, we need to have a _third_ version -// to compare against. -// -// In this case I am not storing the old ACL entirely (though that could be a -// reasonable thing to add in the future), but only its sha256sum. This allows -// us to detect if the shasum in control matches the shasum we expect, and if that -// expectation fails, then we can react accordingly. -type Cache struct { - PrevETag string // Stores the previous ETag of the ACL to allow -} - -// Save persists the cache to a given file. -func (c *Cache) Save(fname string) error { - os.Remove(fname) - fout, err := os.Create(fname) - if err != nil { - return err - } - defer fout.Close() - - return json.NewEncoder(fout).Encode(c) -} - -// LoadCache loads the cache from a given file. -func LoadCache(fname string) (*Cache, error) { - var result Cache - - fin, err := os.Open(fname) - if err != nil { - return nil, err - } - defer fin.Close() - - err = json.NewDecoder(fin).Decode(&result) - if err != nil { - return nil, err - } - - return &result, nil -} - -// Shuck removes the first and last character of a string, analogous to -// shucking off the husk of an ear of corn. -func Shuck(s string) string { - return s[1 : len(s)-1] -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "encoding/json" + "os" +) + +// Cache contains cached information about the last time this tool was run. +// +// This is serialized to a JSON file that should NOT be checked into git. +// It should be managed with either CI cache tools or stored locally somehow. The +// exact mechanism is irrelevant as long as it is consistent. +// +// This allows gitops-pusher to detect external ACL changes. I'm not sure what to +// call this problem, so I've been calling it the "three version problem" in my +// notes. The basic problem is that at any given time we only have two versions +// of the ACL file at any given point. In order to check if there has been +// tampering of the ACL files in the admin panel, we need to have a _third_ version +// to compare against. +// +// In this case I am not storing the old ACL entirely (though that could be a +// reasonable thing to add in the future), but only its sha256sum. This allows +// us to detect if the shasum in control matches the shasum we expect, and if that +// expectation fails, then we can react accordingly. +type Cache struct { + PrevETag string // Stores the previous ETag of the ACL to allow +} + +// Save persists the cache to a given file. +func (c *Cache) Save(fname string) error { + os.Remove(fname) + fout, err := os.Create(fname) + if err != nil { + return err + } + defer fout.Close() + + return json.NewEncoder(fout).Encode(c) +} + +// LoadCache loads the cache from a given file. +func LoadCache(fname string) (*Cache, error) { + var result Cache + + fin, err := os.Open(fname) + if err != nil { + return nil, err + } + defer fin.Close() + + err = json.NewDecoder(fin).Decode(&result) + if err != nil { + return nil, err + } + + return &result, nil +} + +// Shuck removes the first and last character of a string, analogous to +// shucking off the husk of an ear of corn. +func Shuck(s string) string { + return s[1 : len(s)-1] +} diff --git a/cmd/gitops-pusher/gitops-pusher_test.go b/cmd/gitops-pusher/gitops-pusher_test.go index 1beb049c6..b050761d9 100644 --- a/cmd/gitops-pusher/gitops-pusher_test.go +++ b/cmd/gitops-pusher/gitops-pusher_test.go @@ -1,55 +1,55 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause -package main - -import ( - "encoding/json" - "strings" - "testing" - - "tailscale.com/client/tailscale" -) - -func TestEmbeddedTypeUnmarshal(t *testing.T) { - var gitopsErr ACLGitopsTestError - gitopsErr.Message = "gitops response error" - gitopsErr.Data = []tailscale.ACLTestFailureSummary{ - { - User: "GitopsError", - Errors: []string{"this was initially created as a gitops error"}, - }, - } - - var aclTestErr tailscale.ACLTestError - aclTestErr.Message = "native ACL response error" - aclTestErr.Data = []tailscale.ACLTestFailureSummary{ - { - User: "ACLError", - Errors: []string{"this was initially created as an ACL error"}, - }, - } - - t.Run("unmarshal gitops type from acl type", func(t *testing.T) { - b, _ := json.Marshal(aclTestErr) - var e ACLGitopsTestError - err := json.Unmarshal(b, &e) - if err != nil { - t.Fatal(err) - } - if !strings.Contains(e.Error(), "For user ACLError") { // the gitops error prints out the user, the acl error doesn't - t.Fatalf("user heading for 'ACLError' not found in gitops error: %v", e.Error()) - } - }) - t.Run("unmarshal acl type from gitops type", func(t *testing.T) { - b, _ := json.Marshal(gitopsErr) - var e tailscale.ACLTestError - err := json.Unmarshal(b, &e) - if err != nil { - t.Fatal(err) - } - expectedErr := `Status: 0, Message: "gitops response error", Data: [{User:GitopsError Errors:[this was initially created as a gitops error] Warnings:[]}]` - if e.Error() != expectedErr { - t.Fatalf("got %v\n, expected %v", e.Error(), expectedErr) - } - }) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause +package main + +import ( + "encoding/json" + "strings" + "testing" + + "tailscale.com/client/tailscale" +) + +func TestEmbeddedTypeUnmarshal(t *testing.T) { + var gitopsErr ACLGitopsTestError + gitopsErr.Message = "gitops response error" + gitopsErr.Data = []tailscale.ACLTestFailureSummary{ + { + User: "GitopsError", + Errors: []string{"this was initially created as a gitops error"}, + }, + } + + var aclTestErr tailscale.ACLTestError + aclTestErr.Message = "native ACL response error" + aclTestErr.Data = []tailscale.ACLTestFailureSummary{ + { + User: "ACLError", + Errors: []string{"this was initially created as an ACL error"}, + }, + } + + t.Run("unmarshal gitops type from acl type", func(t *testing.T) { + b, _ := json.Marshal(aclTestErr) + var e ACLGitopsTestError + err := json.Unmarshal(b, &e) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(e.Error(), "For user ACLError") { // the gitops error prints out the user, the acl error doesn't + t.Fatalf("user heading for 'ACLError' not found in gitops error: %v", e.Error()) + } + }) + t.Run("unmarshal acl type from gitops type", func(t *testing.T) { + b, _ := json.Marshal(gitopsErr) + var e tailscale.ACLTestError + err := json.Unmarshal(b, &e) + if err != nil { + t.Fatal(err) + } + expectedErr := `Status: 0, Message: "gitops response error", Data: [{User:GitopsError Errors:[this was initially created as a gitops error] Warnings:[]}]` + if e.Error() != expectedErr { + t.Fatalf("got %v\n, expected %v", e.Error(), expectedErr) + } + }) +} diff --git a/cmd/k8s-operator/deploy/chart/.helmignore b/cmd/k8s-operator/deploy/chart/.helmignore index f82e96d46..0e8a0eb36 100644 --- a/cmd/k8s-operator/deploy/chart/.helmignore +++ b/cmd/k8s-operator/deploy/chart/.helmignore @@ -1,23 +1,23 @@ -# Patterns to ignore when building packages. -# This supports shell glob matching, relative path matching, and -# negation (prefixed with !). Only one pattern per line. -.DS_Store -# Common VCS dirs -.git/ -.gitignore -.bzr/ -.bzrignore -.hg/ -.hgignore -.svn/ -# Common backup files -*.swp -*.bak -*.tmp -*.orig -*~ -# Various IDEs -.project -.idea/ -*.tmproj -.vscode/ +# Patterns to ignore when building packages. +# This supports shell glob matching, relative path matching, and +# negation (prefixed with !). Only one pattern per line. +.DS_Store +# Common VCS dirs +.git/ +.gitignore +.bzr/ +.bzrignore +.hg/ +.hgignore +.svn/ +# Common backup files +*.swp +*.bak +*.tmp +*.orig +*~ +# Various IDEs +.project +.idea/ +*.tmproj +.vscode/ diff --git a/cmd/k8s-operator/deploy/chart/Chart.yaml b/cmd/k8s-operator/deploy/chart/Chart.yaml index 472850c41..363d87d15 100644 --- a/cmd/k8s-operator/deploy/chart/Chart.yaml +++ b/cmd/k8s-operator/deploy/chart/Chart.yaml @@ -1,29 +1,29 @@ -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause - -apiVersion: v2 -name: tailscale-operator -description: A Helm chart for Tailscale Kubernetes operator -home: https://github.com/tailscale/tailscale - -keywords: - - "tailscale" - - "vpn" - - "ingress" - - "egress" - - "wireguard" - -sources: -- https://github.com/tailscale/tailscale - -type: application - -maintainers: - - name: tailscale-maintainers - url: https://tailscale.com/ - -# version will be set to Tailscale repo tag (without 'v') at release time. -version: 0.1.0 - -# appVersion will be set to Tailscale repo tag at release time. -appVersion: "unstable" +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause + +apiVersion: v2 +name: tailscale-operator +description: A Helm chart for Tailscale Kubernetes operator +home: https://github.com/tailscale/tailscale + +keywords: + - "tailscale" + - "vpn" + - "ingress" + - "egress" + - "wireguard" + +sources: +- https://github.com/tailscale/tailscale + +type: application + +maintainers: + - name: tailscale-maintainers + url: https://tailscale.com/ + +# version will be set to Tailscale repo tag (without 'v') at release time. +version: 0.1.0 + +# appVersion will be set to Tailscale repo tag at release time. +appVersion: "unstable" diff --git a/cmd/k8s-operator/deploy/chart/templates/apiserverproxy-rbac.yaml b/cmd/k8s-operator/deploy/chart/templates/apiserverproxy-rbac.yaml index 488c87d8a..072ecf6d2 100644 --- a/cmd/k8s-operator/deploy/chart/templates/apiserverproxy-rbac.yaml +++ b/cmd/k8s-operator/deploy/chart/templates/apiserverproxy-rbac.yaml @@ -1,26 +1,26 @@ -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause - -{{ if eq .Values.apiServerProxyConfig.mode "true" }} -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRole -metadata: - name: tailscale-auth-proxy -rules: -- apiGroups: [""] - resources: ["users", "groups"] - verbs: ["impersonate"] ---- -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRoleBinding -metadata: - name: tailscale-auth-proxy -subjects: -- kind: ServiceAccount - name: operator - namespace: {{ .Release.Namespace }} -roleRef: - kind: ClusterRole - name: tailscale-auth-proxy - apiGroup: rbac.authorization.k8s.io -{{ end }} +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause + +{{ if eq .Values.apiServerProxyConfig.mode "true" }} +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + name: tailscale-auth-proxy +rules: +- apiGroups: [""] + resources: ["users", "groups"] + verbs: ["impersonate"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRoleBinding +metadata: + name: tailscale-auth-proxy +subjects: +- kind: ServiceAccount + name: operator + namespace: {{ .Release.Namespace }} +roleRef: + kind: ClusterRole + name: tailscale-auth-proxy + apiGroup: rbac.authorization.k8s.io +{{ end }} diff --git a/cmd/k8s-operator/deploy/chart/templates/oauth-secret.yaml b/cmd/k8s-operator/deploy/chart/templates/oauth-secret.yaml index bde64b7f6..b44fde0a1 100644 --- a/cmd/k8s-operator/deploy/chart/templates/oauth-secret.yaml +++ b/cmd/k8s-operator/deploy/chart/templates/oauth-secret.yaml @@ -1,13 +1,13 @@ -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause - -{{ if and .Values.oauth .Values.oauth.clientId -}} -apiVersion: v1 -kind: Secret -metadata: - name: operator-oauth - namespace: {{ .Release.Namespace }} -stringData: - client_id: {{ .Values.oauth.clientId }} - client_secret: {{ .Values.oauth.clientSecret }} -{{- end -}} +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause + +{{ if and .Values.oauth .Values.oauth.clientId -}} +apiVersion: v1 +kind: Secret +metadata: + name: operator-oauth + namespace: {{ .Release.Namespace }} +stringData: + client_id: {{ .Values.oauth.clientId }} + client_secret: {{ .Values.oauth.clientSecret }} +{{- end -}} diff --git a/cmd/k8s-operator/deploy/manifests/authproxy-rbac.yaml b/cmd/k8s-operator/deploy/manifests/authproxy-rbac.yaml index d957260eb..ddbdda32e 100644 --- a/cmd/k8s-operator/deploy/manifests/authproxy-rbac.yaml +++ b/cmd/k8s-operator/deploy/manifests/authproxy-rbac.yaml @@ -1,24 +1,24 @@ -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause - -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRole -metadata: - name: tailscale-auth-proxy -rules: -- apiGroups: [""] - resources: ["users", "groups"] - verbs: ["impersonate"] ---- -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRoleBinding -metadata: - name: tailscale-auth-proxy -subjects: -- kind: ServiceAccount - name: operator - namespace: tailscale -roleRef: - kind: ClusterRole - name: tailscale-auth-proxy +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause + +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + name: tailscale-auth-proxy +rules: +- apiGroups: [""] + resources: ["users", "groups"] + verbs: ["impersonate"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRoleBinding +metadata: + name: tailscale-auth-proxy +subjects: +- kind: ServiceAccount + name: operator + namespace: tailscale +roleRef: + kind: ClusterRole + name: tailscale-auth-proxy apiGroup: rbac.authorization.k8s.io \ No newline at end of file diff --git a/cmd/mkmanifest/main.go b/cmd/mkmanifest/main.go index 22cd15026..fb3c729f1 100644 --- a/cmd/mkmanifest/main.go +++ b/cmd/mkmanifest/main.go @@ -1,51 +1,51 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The mkmanifest command is a simple helper utility to create a '.syso' file -// that contains a Windows manifest file. -package main - -import ( - "log" - "os" - - "github.com/tc-hib/winres" -) - -func main() { - if len(os.Args) != 4 { - log.Fatalf("usage: %s arch manifest.xml output.syso", os.Args[0]) - } - - arch := winres.Arch(os.Args[1]) - switch arch { - case winres.ArchAMD64, winres.ArchARM64, winres.ArchI386: - default: - log.Fatalf("unsupported arch: %s", arch) - } - - manifest, err := os.ReadFile(os.Args[2]) - if err != nil { - log.Fatalf("error reading manifest file %q: %v", os.Args[2], err) - } - - out := os.Args[3] - - // Start by creating an empty resource set - rs := winres.ResourceSet{} - - // Add resources - rs.Set(winres.RT_MANIFEST, winres.ID(1), 0, manifest) - - // Compile to a COFF object file - f, err := os.Create(out) - if err != nil { - log.Fatalf("error creating output file %q: %v", out, err) - } - if err := rs.WriteObject(f, arch); err != nil { - log.Fatalf("error writing object: %v", err) - } - if err := f.Close(); err != nil { - log.Fatalf("error writing output file %q: %v", out, err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The mkmanifest command is a simple helper utility to create a '.syso' file +// that contains a Windows manifest file. +package main + +import ( + "log" + "os" + + "github.com/tc-hib/winres" +) + +func main() { + if len(os.Args) != 4 { + log.Fatalf("usage: %s arch manifest.xml output.syso", os.Args[0]) + } + + arch := winres.Arch(os.Args[1]) + switch arch { + case winres.ArchAMD64, winres.ArchARM64, winres.ArchI386: + default: + log.Fatalf("unsupported arch: %s", arch) + } + + manifest, err := os.ReadFile(os.Args[2]) + if err != nil { + log.Fatalf("error reading manifest file %q: %v", os.Args[2], err) + } + + out := os.Args[3] + + // Start by creating an empty resource set + rs := winres.ResourceSet{} + + // Add resources + rs.Set(winres.RT_MANIFEST, winres.ID(1), 0, manifest) + + // Compile to a COFF object file + f, err := os.Create(out) + if err != nil { + log.Fatalf("error creating output file %q: %v", out, err) + } + if err := rs.WriteObject(f, arch); err != nil { + log.Fatalf("error writing object: %v", err) + } + if err := f.Close(); err != nil { + log.Fatalf("error writing output file %q: %v", out, err) + } +} diff --git a/cmd/mkpkg/main.go b/cmd/mkpkg/main.go index e942c0162..5e26b07f8 100644 --- a/cmd/mkpkg/main.go +++ b/cmd/mkpkg/main.go @@ -1,134 +1,134 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// mkpkg builds the Tailscale rpm and deb packages. -package main - -import ( - "flag" - "fmt" - "log" - "os" - "strings" - - "github.com/goreleaser/nfpm/v2" - _ "github.com/goreleaser/nfpm/v2/deb" - "github.com/goreleaser/nfpm/v2/files" - _ "github.com/goreleaser/nfpm/v2/rpm" -) - -// parseFiles parses a comma-separated list of colon-separated pairs -// into files.Contents format. -func parseFiles(s string, typ string) (files.Contents, error) { - if len(s) == 0 { - return nil, nil - } - var contents files.Contents - for _, f := range strings.Split(s, ",") { - fs := strings.Split(f, ":") - if len(fs) != 2 { - return nil, fmt.Errorf("unparseable file field %q", f) - } - contents = append(contents, &files.Content{Type: files.TypeFile, Source: fs[0], Destination: fs[1]}) - } - return contents, nil -} - -func parseEmptyDirs(s string) files.Contents { - // strings.Split("", ",") would return []string{""}, which is not suitable: - // this would create an empty dir record with path "", breaking the package - if s == "" { - return nil - } - var contents files.Contents - for _, d := range strings.Split(s, ",") { - contents = append(contents, &files.Content{Type: files.TypeDir, Destination: d}) - } - return contents -} - -func main() { - out := flag.String("out", "", "output file to write") - name := flag.String("name", "tailscale", "package name") - description := flag.String("description", "The easiest, most secure, cross platform way to use WireGuard + oauth2 + 2FA/SSO", "package description") - goarch := flag.String("arch", "amd64", "GOARCH this package is for") - pkgType := flag.String("type", "deb", "type of package to build (deb or rpm)") - regularFiles := flag.String("files", "", "comma-separated list of files in src:dst form") - configFiles := flag.String("configs", "", "like --files, but for files marked as user-editable config files") - emptyDirs := flag.String("emptydirs", "", "comma-separated list of empty directories") - version := flag.String("version", "0.0.0", "version of the package") - postinst := flag.String("postinst", "", "debian postinst script path") - prerm := flag.String("prerm", "", "debian prerm script path") - postrm := flag.String("postrm", "", "debian postrm script path") - replaces := flag.String("replaces", "", "package which this package replaces, if any") - depends := flag.String("depends", "", "comma-separated list of packages this package depends on") - recommends := flag.String("recommends", "", "comma-separated list of packages this package recommends") - flag.Parse() - - filesList, err := parseFiles(*regularFiles, files.TypeFile) - if err != nil { - log.Fatalf("Parsing --files: %v", err) - } - configsList, err := parseFiles(*configFiles, files.TypeConfig) - if err != nil { - log.Fatalf("Parsing --configs: %v", err) - } - emptyDirList := parseEmptyDirs(*emptyDirs) - contents := append(filesList, append(configsList, emptyDirList...)...) - contents, err = files.PrepareForPackager(contents, 0, *pkgType, false) - if err != nil { - log.Fatalf("Building package contents: %v", err) - } - info := nfpm.WithDefaults(&nfpm.Info{ - Name: *name, - Arch: *goarch, - Platform: "linux", - Version: *version, - Maintainer: "Tailscale Inc ", - Description: *description, - Homepage: "https://www.tailscale.com", - License: "MIT", - Overridables: nfpm.Overridables{ - Contents: contents, - Scripts: nfpm.Scripts{ - PostInstall: *postinst, - PreRemove: *prerm, - PostRemove: *postrm, - }, - }, - }) - - if len(*depends) != 0 { - info.Overridables.Depends = strings.Split(*depends, ",") - } - if len(*recommends) != 0 { - info.Overridables.Recommends = strings.Split(*recommends, ",") - } - if *replaces != "" { - info.Overridables.Replaces = []string{*replaces} - info.Overridables.Conflicts = []string{*replaces} - } - - switch *pkgType { - case "deb": - info.Section = "net" - info.Priority = "extra" - case "rpm": - info.Overridables.RPM.Group = "Network" - } - - pkg, err := nfpm.Get(*pkgType) - if err != nil { - log.Fatalf("Getting packager for %q: %v", *pkgType, err) - } - - f, err := os.Create(*out) - if err != nil { - log.Fatalf("Creating output file %q: %v", *out, err) - } - defer f.Close() - - if err := pkg.Package(info, f); err != nil { - log.Fatalf("Creating package %q: %v", *out, err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// mkpkg builds the Tailscale rpm and deb packages. +package main + +import ( + "flag" + "fmt" + "log" + "os" + "strings" + + "github.com/goreleaser/nfpm/v2" + _ "github.com/goreleaser/nfpm/v2/deb" + "github.com/goreleaser/nfpm/v2/files" + _ "github.com/goreleaser/nfpm/v2/rpm" +) + +// parseFiles parses a comma-separated list of colon-separated pairs +// into files.Contents format. +func parseFiles(s string, typ string) (files.Contents, error) { + if len(s) == 0 { + return nil, nil + } + var contents files.Contents + for _, f := range strings.Split(s, ",") { + fs := strings.Split(f, ":") + if len(fs) != 2 { + return nil, fmt.Errorf("unparseable file field %q", f) + } + contents = append(contents, &files.Content{Type: files.TypeFile, Source: fs[0], Destination: fs[1]}) + } + return contents, nil +} + +func parseEmptyDirs(s string) files.Contents { + // strings.Split("", ",") would return []string{""}, which is not suitable: + // this would create an empty dir record with path "", breaking the package + if s == "" { + return nil + } + var contents files.Contents + for _, d := range strings.Split(s, ",") { + contents = append(contents, &files.Content{Type: files.TypeDir, Destination: d}) + } + return contents +} + +func main() { + out := flag.String("out", "", "output file to write") + name := flag.String("name", "tailscale", "package name") + description := flag.String("description", "The easiest, most secure, cross platform way to use WireGuard + oauth2 + 2FA/SSO", "package description") + goarch := flag.String("arch", "amd64", "GOARCH this package is for") + pkgType := flag.String("type", "deb", "type of package to build (deb or rpm)") + regularFiles := flag.String("files", "", "comma-separated list of files in src:dst form") + configFiles := flag.String("configs", "", "like --files, but for files marked as user-editable config files") + emptyDirs := flag.String("emptydirs", "", "comma-separated list of empty directories") + version := flag.String("version", "0.0.0", "version of the package") + postinst := flag.String("postinst", "", "debian postinst script path") + prerm := flag.String("prerm", "", "debian prerm script path") + postrm := flag.String("postrm", "", "debian postrm script path") + replaces := flag.String("replaces", "", "package which this package replaces, if any") + depends := flag.String("depends", "", "comma-separated list of packages this package depends on") + recommends := flag.String("recommends", "", "comma-separated list of packages this package recommends") + flag.Parse() + + filesList, err := parseFiles(*regularFiles, files.TypeFile) + if err != nil { + log.Fatalf("Parsing --files: %v", err) + } + configsList, err := parseFiles(*configFiles, files.TypeConfig) + if err != nil { + log.Fatalf("Parsing --configs: %v", err) + } + emptyDirList := parseEmptyDirs(*emptyDirs) + contents := append(filesList, append(configsList, emptyDirList...)...) + contents, err = files.PrepareForPackager(contents, 0, *pkgType, false) + if err != nil { + log.Fatalf("Building package contents: %v", err) + } + info := nfpm.WithDefaults(&nfpm.Info{ + Name: *name, + Arch: *goarch, + Platform: "linux", + Version: *version, + Maintainer: "Tailscale Inc ", + Description: *description, + Homepage: "https://www.tailscale.com", + License: "MIT", + Overridables: nfpm.Overridables{ + Contents: contents, + Scripts: nfpm.Scripts{ + PostInstall: *postinst, + PreRemove: *prerm, + PostRemove: *postrm, + }, + }, + }) + + if len(*depends) != 0 { + info.Overridables.Depends = strings.Split(*depends, ",") + } + if len(*recommends) != 0 { + info.Overridables.Recommends = strings.Split(*recommends, ",") + } + if *replaces != "" { + info.Overridables.Replaces = []string{*replaces} + info.Overridables.Conflicts = []string{*replaces} + } + + switch *pkgType { + case "deb": + info.Section = "net" + info.Priority = "extra" + case "rpm": + info.Overridables.RPM.Group = "Network" + } + + pkg, err := nfpm.Get(*pkgType) + if err != nil { + log.Fatalf("Getting packager for %q: %v", *pkgType, err) + } + + f, err := os.Create(*out) + if err != nil { + log.Fatalf("Creating output file %q: %v", *out, err) + } + defer f.Close() + + if err := pkg.Package(info, f); err != nil { + log.Fatalf("Creating package %q: %v", *out, err) + } +} diff --git a/cmd/mkversion/mkversion.go b/cmd/mkversion/mkversion.go index 6a6a18a50..c8c8bf179 100644 --- a/cmd/mkversion/mkversion.go +++ b/cmd/mkversion/mkversion.go @@ -1,44 +1,44 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// mkversion gets version info from git and outputs a bunch of shell variables -// that get used elsewhere in the build system to embed version numbers into -// binaries. -package main - -import ( - "bufio" - "bytes" - "fmt" - "io" - "os" - "time" - - "tailscale.com/tailcfg" - "tailscale.com/version/mkversion" -) - -func main() { - prefix := "" - if len(os.Args) > 1 { - if os.Args[1] == "--export" { - prefix = "export " - } else { - fmt.Println("usage: mkversion [--export|-h|--help]") - os.Exit(1) - } - } - - var b bytes.Buffer - io.WriteString(&b, mkversion.Info().String()) - // Copyright and the client capability are not part of the version - // information, but similarly used in Xcode builds to embed in the metadata, - // thus generate them now. - copyright := fmt.Sprintf("Copyright © %d Tailscale Inc. All Rights Reserved.", time.Now().Year()) - fmt.Fprintf(&b, "VERSION_COPYRIGHT=%q\n", copyright) - fmt.Fprintf(&b, "VERSION_CAPABILITY=%d\n", tailcfg.CurrentCapabilityVersion) - s := bufio.NewScanner(&b) - for s.Scan() { - fmt.Println(prefix + s.Text()) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// mkversion gets version info from git and outputs a bunch of shell variables +// that get used elsewhere in the build system to embed version numbers into +// binaries. +package main + +import ( + "bufio" + "bytes" + "fmt" + "io" + "os" + "time" + + "tailscale.com/tailcfg" + "tailscale.com/version/mkversion" +) + +func main() { + prefix := "" + if len(os.Args) > 1 { + if os.Args[1] == "--export" { + prefix = "export " + } else { + fmt.Println("usage: mkversion [--export|-h|--help]") + os.Exit(1) + } + } + + var b bytes.Buffer + io.WriteString(&b, mkversion.Info().String()) + // Copyright and the client capability are not part of the version + // information, but similarly used in Xcode builds to embed in the metadata, + // thus generate them now. + copyright := fmt.Sprintf("Copyright © %d Tailscale Inc. All Rights Reserved.", time.Now().Year()) + fmt.Fprintf(&b, "VERSION_COPYRIGHT=%q\n", copyright) + fmt.Fprintf(&b, "VERSION_CAPABILITY=%d\n", tailcfg.CurrentCapabilityVersion) + s := bufio.NewScanner(&b) + for s.Scan() { + fmt.Println(prefix + s.Text()) + } +} diff --git a/cmd/nardump/README.md b/cmd/nardump/README.md index 6c73ff9b0..6fa7fc2f1 100644 --- a/cmd/nardump/README.md +++ b/cmd/nardump/README.md @@ -1,7 +1,7 @@ -# nardump - -nardump is like nix-store --dump, but in Go, writing a NAR file (tar-like, -but focused on being reproducible) to stdout or to a hash with the --sri flag. - -It lets us calculate the Nix sha256 in shell.nix without the person running -git-pull-oss.sh having Nix available. +# nardump + +nardump is like nix-store --dump, but in Go, writing a NAR file (tar-like, +but focused on being reproducible) to stdout or to a hash with the --sri flag. + +It lets us calculate the Nix sha256 in shell.nix without the person running +git-pull-oss.sh having Nix available. diff --git a/cmd/nardump/nardump.go b/cmd/nardump/nardump.go index 241475537..05be7b65a 100644 --- a/cmd/nardump/nardump.go +++ b/cmd/nardump/nardump.go @@ -1,184 +1,184 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// nardump is like nix-store --dump, but in Go, writing a NAR -// file (tar-like, but focused on being reproducible) to stdout -// or to a hash with the --sri flag. -// -// It lets us calculate a Nix sha256 without the person running -// git-pull-oss.sh having Nix available. -package main - -// For the format, see: -// See https://gist.github.com/jbeda/5c79d2b1434f0018d693 - -import ( - "bufio" - "crypto/sha256" - "encoding/base64" - "encoding/binary" - "flag" - "fmt" - "io" - "io/fs" - "log" - "os" - "path" - "sort" -) - -var sri = flag.Bool("sri", false, "print SRI") - -func main() { - flag.Parse() - if flag.NArg() != 1 { - log.Fatal("usage: nardump ") - } - arg := flag.Arg(0) - if err := os.Chdir(arg); err != nil { - log.Fatal(err) - } - if *sri { - hash := sha256.New() - if err := writeNAR(hash, os.DirFS(".")); err != nil { - log.Fatal(err) - } - fmt.Printf("sha256-%s\n", base64.StdEncoding.EncodeToString(hash.Sum(nil))) - return - } - bw := bufio.NewWriter(os.Stdout) - if err := writeNAR(bw, os.DirFS(".")); err != nil { - log.Fatal(err) - } - bw.Flush() -} - -// writeNARError is a sentinel panic type that's recovered by writeNAR -// and converted into the wrapped error. -type writeNARError struct{ err error } - -// narWriter writes NAR files. -type narWriter struct { - w io.Writer - fs fs.FS -} - -// writeNAR writes a NAR file to w from the root of fs. -func writeNAR(w io.Writer, fs fs.FS) (err error) { - defer func() { - if e := recover(); e != nil { - if we, ok := e.(writeNARError); ok { - err = we.err - return - } - panic(e) - } - }() - nw := &narWriter{w: w, fs: fs} - nw.str("nix-archive-1") - return nw.writeDir(".") -} - -func (nw *narWriter) writeDir(dirPath string) error { - ents, err := fs.ReadDir(nw.fs, dirPath) - if err != nil { - return err - } - sort.Slice(ents, func(i, j int) bool { - return ents[i].Name() < ents[j].Name() - }) - nw.str("(") - nw.str("type") - nw.str("directory") - for _, ent := range ents { - nw.str("entry") - nw.str("(") - nw.str("name") - nw.str(ent.Name()) - nw.str("node") - mode := ent.Type() - sub := path.Join(dirPath, ent.Name()) - var err error - switch { - case mode.IsRegular(): - err = nw.writeRegular(sub) - case mode.IsDir(): - err = nw.writeDir(sub) - default: - // TODO(bradfitz): symlink, but requires fighting io/fs a bit - // to get at Readlink or the osFS via fs. But for now - // we don't need symlinks because they're not in Go's archive. - return fmt.Errorf("unsupported file type %v at %q", sub, mode) - } - if err != nil { - return err - } - nw.str(")") - } - nw.str(")") - return nil -} - -func (nw *narWriter) writeRegular(path string) error { - nw.str("(") - nw.str("type") - nw.str("regular") - fi, err := fs.Stat(nw.fs, path) - if err != nil { - return err - } - if fi.Mode()&0111 != 0 { - nw.str("executable") - nw.str("") - } - contents, err := fs.ReadFile(nw.fs, path) - if err != nil { - return err - } - nw.str("contents") - if err := writeBytes(nw.w, contents); err != nil { - return err - } - nw.str(")") - return nil -} - -func (nw *narWriter) str(s string) { - if err := writeString(nw.w, s); err != nil { - panic(writeNARError{err}) - } -} - -func writeString(w io.Writer, s string) error { - var buf [8]byte - binary.LittleEndian.PutUint64(buf[:], uint64(len(s))) - if _, err := w.Write(buf[:]); err != nil { - return err - } - if _, err := io.WriteString(w, s); err != nil { - return err - } - return writePad(w, len(s)) -} - -func writeBytes(w io.Writer, b []byte) error { - var buf [8]byte - binary.LittleEndian.PutUint64(buf[:], uint64(len(b))) - if _, err := w.Write(buf[:]); err != nil { - return err - } - if _, err := w.Write(b); err != nil { - return err - } - return writePad(w, len(b)) -} - -func writePad(w io.Writer, n int) error { - pad := n % 8 - if pad == 0 { - return nil - } - var zeroes [8]byte - _, err := w.Write(zeroes[:8-pad]) - return err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// nardump is like nix-store --dump, but in Go, writing a NAR +// file (tar-like, but focused on being reproducible) to stdout +// or to a hash with the --sri flag. +// +// It lets us calculate a Nix sha256 without the person running +// git-pull-oss.sh having Nix available. +package main + +// For the format, see: +// See https://gist.github.com/jbeda/5c79d2b1434f0018d693 + +import ( + "bufio" + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "flag" + "fmt" + "io" + "io/fs" + "log" + "os" + "path" + "sort" +) + +var sri = flag.Bool("sri", false, "print SRI") + +func main() { + flag.Parse() + if flag.NArg() != 1 { + log.Fatal("usage: nardump ") + } + arg := flag.Arg(0) + if err := os.Chdir(arg); err != nil { + log.Fatal(err) + } + if *sri { + hash := sha256.New() + if err := writeNAR(hash, os.DirFS(".")); err != nil { + log.Fatal(err) + } + fmt.Printf("sha256-%s\n", base64.StdEncoding.EncodeToString(hash.Sum(nil))) + return + } + bw := bufio.NewWriter(os.Stdout) + if err := writeNAR(bw, os.DirFS(".")); err != nil { + log.Fatal(err) + } + bw.Flush() +} + +// writeNARError is a sentinel panic type that's recovered by writeNAR +// and converted into the wrapped error. +type writeNARError struct{ err error } + +// narWriter writes NAR files. +type narWriter struct { + w io.Writer + fs fs.FS +} + +// writeNAR writes a NAR file to w from the root of fs. +func writeNAR(w io.Writer, fs fs.FS) (err error) { + defer func() { + if e := recover(); e != nil { + if we, ok := e.(writeNARError); ok { + err = we.err + return + } + panic(e) + } + }() + nw := &narWriter{w: w, fs: fs} + nw.str("nix-archive-1") + return nw.writeDir(".") +} + +func (nw *narWriter) writeDir(dirPath string) error { + ents, err := fs.ReadDir(nw.fs, dirPath) + if err != nil { + return err + } + sort.Slice(ents, func(i, j int) bool { + return ents[i].Name() < ents[j].Name() + }) + nw.str("(") + nw.str("type") + nw.str("directory") + for _, ent := range ents { + nw.str("entry") + nw.str("(") + nw.str("name") + nw.str(ent.Name()) + nw.str("node") + mode := ent.Type() + sub := path.Join(dirPath, ent.Name()) + var err error + switch { + case mode.IsRegular(): + err = nw.writeRegular(sub) + case mode.IsDir(): + err = nw.writeDir(sub) + default: + // TODO(bradfitz): symlink, but requires fighting io/fs a bit + // to get at Readlink or the osFS via fs. But for now + // we don't need symlinks because they're not in Go's archive. + return fmt.Errorf("unsupported file type %v at %q", sub, mode) + } + if err != nil { + return err + } + nw.str(")") + } + nw.str(")") + return nil +} + +func (nw *narWriter) writeRegular(path string) error { + nw.str("(") + nw.str("type") + nw.str("regular") + fi, err := fs.Stat(nw.fs, path) + if err != nil { + return err + } + if fi.Mode()&0111 != 0 { + nw.str("executable") + nw.str("") + } + contents, err := fs.ReadFile(nw.fs, path) + if err != nil { + return err + } + nw.str("contents") + if err := writeBytes(nw.w, contents); err != nil { + return err + } + nw.str(")") + return nil +} + +func (nw *narWriter) str(s string) { + if err := writeString(nw.w, s); err != nil { + panic(writeNARError{err}) + } +} + +func writeString(w io.Writer, s string) error { + var buf [8]byte + binary.LittleEndian.PutUint64(buf[:], uint64(len(s))) + if _, err := w.Write(buf[:]); err != nil { + return err + } + if _, err := io.WriteString(w, s); err != nil { + return err + } + return writePad(w, len(s)) +} + +func writeBytes(w io.Writer, b []byte) error { + var buf [8]byte + binary.LittleEndian.PutUint64(buf[:], uint64(len(b))) + if _, err := w.Write(buf[:]); err != nil { + return err + } + if _, err := w.Write(b); err != nil { + return err + } + return writePad(w, len(b)) +} + +func writePad(w io.Writer, n int) error { + pad := n % 8 + if pad == 0 { + return nil + } + var zeroes [8]byte + _, err := w.Write(zeroes[:8-pad]) + return err +} diff --git a/cmd/nginx-auth/.gitignore b/cmd/nginx-auth/.gitignore index 255276578..3c608aeb1 100644 --- a/cmd/nginx-auth/.gitignore +++ b/cmd/nginx-auth/.gitignore @@ -1,4 +1,4 @@ -nga.sock -*.deb -*.rpm -tailscale.nginx-auth +nga.sock +*.deb +*.rpm +tailscale.nginx-auth diff --git a/cmd/nginx-auth/README.md b/cmd/nginx-auth/README.md index 869b1487b..858f9ab81 100644 --- a/cmd/nginx-auth/README.md +++ b/cmd/nginx-auth/README.md @@ -1,161 +1,161 @@ -# nginx-auth - -[![status: experimental](https://img.shields.io/badge/status-experimental-blue)](https://tailscale.com/kb/1167/release-stages/#experimental) - -This is a tool that allows users to use Tailscale Whois authentication with -NGINX as a reverse proxy. This allows users that already have a bunch of -services hosted on an internal NGINX server to point those domains to the -Tailscale IP of the NGINX server and then seamlessly use Tailscale for -authentication. - -Many thanks to [@zrail](https://twitter.com/zrail/status/1511788463586222087) on -Twitter for introducing the basic idea and offering some sample code. This -program is based on that sample code with security enhancements. Namely: - -* This listens over a UNIX socket instead of a TCP socket, to prevent - leakage to the network -* This uses systemd socket activation so that systemd owns the socket - and can then lock down the service to the bare minimum required to do - its job without having to worry about dropping permissions -* This provides additional information in HTTP response headers that can - be useful for integrating with various services - -## Configuration - -In order to protect a service with this tool, do the following in the respective -`server` block: - -Create an authentication location with the `internal` flag set: - -```nginx -location /auth { - internal; - - proxy_pass http://unix:/run/tailscale.nginx-auth.sock; - proxy_pass_request_body off; - - proxy_set_header Host $http_host; - proxy_set_header Remote-Addr $remote_addr; - proxy_set_header Remote-Port $remote_port; - proxy_set_header Original-URI $request_uri; -} -``` - -Then add the following to the `location /` block: - -``` -auth_request /auth; -auth_request_set $auth_user $upstream_http_tailscale_user; -auth_request_set $auth_name $upstream_http_tailscale_name; -auth_request_set $auth_login $upstream_http_tailscale_login; -auth_request_set $auth_tailnet $upstream_http_tailscale_tailnet; -auth_request_set $auth_profile_picture $upstream_http_tailscale_profile_picture; - -proxy_set_header X-Webauth-User "$auth_user"; -proxy_set_header X-Webauth-Name "$auth_name"; -proxy_set_header X-Webauth-Login "$auth_login"; -proxy_set_header X-Webauth-Tailnet "$auth_tailnet"; -proxy_set_header X-Webauth-Profile-Picture "$auth_profile_picture"; -``` - -When this configuration is used with a Go HTTP handler such as this: - -```go -http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { - e := json.NewEncoder(w) - e.SetIndent("", " ") - e.Encode(r.Header) -}) -``` - -You will get output like this: - -```json -{ - "Accept": [ - "*/*" - ], - "Connection": [ - "upgrade" - ], - "User-Agent": [ - "curl/7.82.0" - ], - "X-Webauth-Login": [ - "Xe" - ], - "X-Webauth-Name": [ - "Xe Iaso" - ], - "X-Webauth-Profile-Picture": [ - "https://avatars.githubusercontent.com/u/529003?v=4" - ], - "X-Webauth-Tailnet": [ - "cetacean.org.github" - ] - "X-Webauth-User": [ - "Xe@github" - ] -} -``` - -## Headers - -The authentication service provides the following headers to decorate your -proxied requests: - -| Header | Example Value | Description | -| :------ | :-------------- | :---------- | -| `Tailscale-User` | `azurediamond@hunter2.net` | The Tailscale username the remote machine is logged in as in user@host form | -| `Tailscale-Login` | `azurediamond` | The user portion of the Tailscale username the remote machine is logged in as | -| `Tailscale-Name` | `Azure Diamond` | The "real name" of the Tailscale user the machine is logged in as | -| `Tailscale-Profile-Picture` | `https://i.kym-cdn.com/photos/images/newsfeed/001/065/963/ae0.png` | The profile picture provided by the Identity Provider your tailnet uses | -| `Tailscale-Tailnet` | `hunter2.net` | The tailnet name | - -Most of the time you can set `X-Webauth-User` to the contents of the -`Tailscale-User` header, but some services may not accept a username with an `@` -symbol in it. If this is the case, set `X-Webauth-User` to the `Tailscale-Login` -header. - -The `Tailscale-Tailnet` header can help you identify which tailnet the session -is coming from. If you are using node sharing, this can help you make sure that -you aren't giving administrative access to people outside your tailnet. - -### Allow Requests From Only One Tailnet - -If you want to prevent node sharing from allowing users to access a service, add -the `Expected-Tailnet` header to your auth request: - -```nginx -location /auth { - # ... - proxy_set_header Expected-Tailnet "tailnet012345.ts.net"; -} -``` - -If a user from a different tailnet tries to use that service, this will return a -generic "forbidden" error page: - -```html - -403 Forbidden - -

403 Forbidden

-
nginx/1.18.0 (Ubuntu)
- - -``` - -You can get the tailnet name from [the admin panel](https://login.tailscale.com/admin/dns). - -## Building - -Install `cmd/mkpkg`: - -``` -cd .. && go install ./mkpkg -``` - -Then run `./mkdeb.sh`. It will emit a `.deb` and `.rpm` package for amd64 -machines (Linux uname flag: `x86_64`). You can add these to your deployment -methods as you see fit. +# nginx-auth + +[![status: experimental](https://img.shields.io/badge/status-experimental-blue)](https://tailscale.com/kb/1167/release-stages/#experimental) + +This is a tool that allows users to use Tailscale Whois authentication with +NGINX as a reverse proxy. This allows users that already have a bunch of +services hosted on an internal NGINX server to point those domains to the +Tailscale IP of the NGINX server and then seamlessly use Tailscale for +authentication. + +Many thanks to [@zrail](https://twitter.com/zrail/status/1511788463586222087) on +Twitter for introducing the basic idea and offering some sample code. This +program is based on that sample code with security enhancements. Namely: + +* This listens over a UNIX socket instead of a TCP socket, to prevent + leakage to the network +* This uses systemd socket activation so that systemd owns the socket + and can then lock down the service to the bare minimum required to do + its job without having to worry about dropping permissions +* This provides additional information in HTTP response headers that can + be useful for integrating with various services + +## Configuration + +In order to protect a service with this tool, do the following in the respective +`server` block: + +Create an authentication location with the `internal` flag set: + +```nginx +location /auth { + internal; + + proxy_pass http://unix:/run/tailscale.nginx-auth.sock; + proxy_pass_request_body off; + + proxy_set_header Host $http_host; + proxy_set_header Remote-Addr $remote_addr; + proxy_set_header Remote-Port $remote_port; + proxy_set_header Original-URI $request_uri; +} +``` + +Then add the following to the `location /` block: + +``` +auth_request /auth; +auth_request_set $auth_user $upstream_http_tailscale_user; +auth_request_set $auth_name $upstream_http_tailscale_name; +auth_request_set $auth_login $upstream_http_tailscale_login; +auth_request_set $auth_tailnet $upstream_http_tailscale_tailnet; +auth_request_set $auth_profile_picture $upstream_http_tailscale_profile_picture; + +proxy_set_header X-Webauth-User "$auth_user"; +proxy_set_header X-Webauth-Name "$auth_name"; +proxy_set_header X-Webauth-Login "$auth_login"; +proxy_set_header X-Webauth-Tailnet "$auth_tailnet"; +proxy_set_header X-Webauth-Profile-Picture "$auth_profile_picture"; +``` + +When this configuration is used with a Go HTTP handler such as this: + +```go +http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { + e := json.NewEncoder(w) + e.SetIndent("", " ") + e.Encode(r.Header) +}) +``` + +You will get output like this: + +```json +{ + "Accept": [ + "*/*" + ], + "Connection": [ + "upgrade" + ], + "User-Agent": [ + "curl/7.82.0" + ], + "X-Webauth-Login": [ + "Xe" + ], + "X-Webauth-Name": [ + "Xe Iaso" + ], + "X-Webauth-Profile-Picture": [ + "https://avatars.githubusercontent.com/u/529003?v=4" + ], + "X-Webauth-Tailnet": [ + "cetacean.org.github" + ] + "X-Webauth-User": [ + "Xe@github" + ] +} +``` + +## Headers + +The authentication service provides the following headers to decorate your +proxied requests: + +| Header | Example Value | Description | +| :------ | :-------------- | :---------- | +| `Tailscale-User` | `azurediamond@hunter2.net` | The Tailscale username the remote machine is logged in as in user@host form | +| `Tailscale-Login` | `azurediamond` | The user portion of the Tailscale username the remote machine is logged in as | +| `Tailscale-Name` | `Azure Diamond` | The "real name" of the Tailscale user the machine is logged in as | +| `Tailscale-Profile-Picture` | `https://i.kym-cdn.com/photos/images/newsfeed/001/065/963/ae0.png` | The profile picture provided by the Identity Provider your tailnet uses | +| `Tailscale-Tailnet` | `hunter2.net` | The tailnet name | + +Most of the time you can set `X-Webauth-User` to the contents of the +`Tailscale-User` header, but some services may not accept a username with an `@` +symbol in it. If this is the case, set `X-Webauth-User` to the `Tailscale-Login` +header. + +The `Tailscale-Tailnet` header can help you identify which tailnet the session +is coming from. If you are using node sharing, this can help you make sure that +you aren't giving administrative access to people outside your tailnet. + +### Allow Requests From Only One Tailnet + +If you want to prevent node sharing from allowing users to access a service, add +the `Expected-Tailnet` header to your auth request: + +```nginx +location /auth { + # ... + proxy_set_header Expected-Tailnet "tailnet012345.ts.net"; +} +``` + +If a user from a different tailnet tries to use that service, this will return a +generic "forbidden" error page: + +```html + +403 Forbidden + +

403 Forbidden

+
nginx/1.18.0 (Ubuntu)
+ + +``` + +You can get the tailnet name from [the admin panel](https://login.tailscale.com/admin/dns). + +## Building + +Install `cmd/mkpkg`: + +``` +cd .. && go install ./mkpkg +``` + +Then run `./mkdeb.sh`. It will emit a `.deb` and `.rpm` package for amd64 +machines (Linux uname flag: `x86_64`). You can add these to your deployment +methods as you see fit. diff --git a/cmd/nginx-auth/deb/postinst.sh b/cmd/nginx-auth/deb/postinst.sh index e692ced07..d352a8488 100755 --- a/cmd/nginx-auth/deb/postinst.sh +++ b/cmd/nginx-auth/deb/postinst.sh @@ -1,14 +1,14 @@ -if [ "$1" = "configure" ] || [ "$1" = "abort-upgrade" ] || [ "$1" = "abort-deconfigure" ] || [ "$1" = "abort-remove" ] ; then - deb-systemd-helper unmask 'tailscale.nginx-auth.socket' >/dev/null || true - if deb-systemd-helper --quiet was-enabled 'tailscale.nginx-auth.socket'; then - deb-systemd-helper enable 'tailscale.nginx-auth.socket' >/dev/null || true - else - deb-systemd-helper update-state 'tailscale.nginx-auth.socket' >/dev/null || true - fi - - if systemctl is-active tailscale.nginx-auth.socket >/dev/null; then - systemctl --system daemon-reload >/dev/null || true - deb-systemd-invoke stop 'tailscale.nginx-auth.service' >/dev/null || true - deb-systemd-invoke restart 'tailscale.nginx-auth.socket' >/dev/null || true - fi -fi +if [ "$1" = "configure" ] || [ "$1" = "abort-upgrade" ] || [ "$1" = "abort-deconfigure" ] || [ "$1" = "abort-remove" ] ; then + deb-systemd-helper unmask 'tailscale.nginx-auth.socket' >/dev/null || true + if deb-systemd-helper --quiet was-enabled 'tailscale.nginx-auth.socket'; then + deb-systemd-helper enable 'tailscale.nginx-auth.socket' >/dev/null || true + else + deb-systemd-helper update-state 'tailscale.nginx-auth.socket' >/dev/null || true + fi + + if systemctl is-active tailscale.nginx-auth.socket >/dev/null; then + systemctl --system daemon-reload >/dev/null || true + deb-systemd-invoke stop 'tailscale.nginx-auth.service' >/dev/null || true + deb-systemd-invoke restart 'tailscale.nginx-auth.socket' >/dev/null || true + fi +fi diff --git a/cmd/nginx-auth/deb/postrm.sh b/cmd/nginx-auth/deb/postrm.sh index 7870efd18..4bce86139 100755 --- a/cmd/nginx-auth/deb/postrm.sh +++ b/cmd/nginx-auth/deb/postrm.sh @@ -1,19 +1,19 @@ -#!/bin/sh -set -e -if [ -d /run/systemd/system ] ; then - systemctl --system daemon-reload >/dev/null || true -fi - -if [ -x "/usr/bin/deb-systemd-helper" ]; then - if [ "$1" = "remove" ]; then - deb-systemd-helper mask 'tailscale.nginx-auth.socket' >/dev/null || true - deb-systemd-helper mask 'tailscale.nginx-auth.service' >/dev/null || true - fi - - if [ "$1" = "purge" ]; then - deb-systemd-helper purge 'tailscale.nginx-auth.socket' >/dev/null || true - deb-systemd-helper unmask 'tailscale.nginx-auth.socket' >/dev/null || true - deb-systemd-helper purge 'tailscale.nginx-auth.service' >/dev/null || true - deb-systemd-helper unmask 'tailscale.nginx-auth.service' >/dev/null || true - fi -fi +#!/bin/sh +set -e +if [ -d /run/systemd/system ] ; then + systemctl --system daemon-reload >/dev/null || true +fi + +if [ -x "/usr/bin/deb-systemd-helper" ]; then + if [ "$1" = "remove" ]; then + deb-systemd-helper mask 'tailscale.nginx-auth.socket' >/dev/null || true + deb-systemd-helper mask 'tailscale.nginx-auth.service' >/dev/null || true + fi + + if [ "$1" = "purge" ]; then + deb-systemd-helper purge 'tailscale.nginx-auth.socket' >/dev/null || true + deb-systemd-helper unmask 'tailscale.nginx-auth.socket' >/dev/null || true + deb-systemd-helper purge 'tailscale.nginx-auth.service' >/dev/null || true + deb-systemd-helper unmask 'tailscale.nginx-auth.service' >/dev/null || true + fi +fi diff --git a/cmd/nginx-auth/deb/prerm.sh b/cmd/nginx-auth/deb/prerm.sh index 22be23387..e4becd170 100755 --- a/cmd/nginx-auth/deb/prerm.sh +++ b/cmd/nginx-auth/deb/prerm.sh @@ -1,8 +1,8 @@ -#!/bin/sh -set -e -if [ "$1" = "remove" ]; then - if [ -d /run/systemd/system ]; then - deb-systemd-invoke stop 'tailscale.nginx-auth.service' >/dev/null || true - deb-systemd-invoke stop 'tailscale.nginx-auth.socket' >/dev/null || true - fi -fi +#!/bin/sh +set -e +if [ "$1" = "remove" ]; then + if [ -d /run/systemd/system ]; then + deb-systemd-invoke stop 'tailscale.nginx-auth.service' >/dev/null || true + deb-systemd-invoke stop 'tailscale.nginx-auth.socket' >/dev/null || true + fi +fi diff --git a/cmd/nginx-auth/mkdeb.sh b/cmd/nginx-auth/mkdeb.sh index 6a5721093..59f43230d 100755 --- a/cmd/nginx-auth/mkdeb.sh +++ b/cmd/nginx-auth/mkdeb.sh @@ -1,32 +1,32 @@ -#!/usr/bin/env bash - -set -e - -VERSION=0.1.3 -for ARCH in amd64 arm64; do - CGO_ENABLED=0 GOARCH=${ARCH} GOOS=linux go build -o tailscale.nginx-auth . - - mkpkg \ - --out=tailscale-nginx-auth-${VERSION}-${ARCH}.deb \ - --name=tailscale-nginx-auth \ - --version=${VERSION} \ - --type=deb \ - --arch=${ARCH} \ - --postinst=deb/postinst.sh \ - --postrm=deb/postrm.sh \ - --prerm=deb/prerm.sh \ - --description="Tailscale NGINX authentication protocol handler" \ - --files=./tailscale.nginx-auth:/usr/sbin/tailscale.nginx-auth,./tailscale.nginx-auth.socket:/lib/systemd/system/tailscale.nginx-auth.socket,./tailscale.nginx-auth.service:/lib/systemd/system/tailscale.nginx-auth.service,./README.md:/usr/share/tailscale/nginx-auth/README.md - - mkpkg \ - --out=tailscale-nginx-auth-${VERSION}-${ARCH}.rpm \ - --name=tailscale-nginx-auth \ - --version=${VERSION} \ - --type=rpm \ - --arch=${ARCH} \ - --postinst=rpm/postinst.sh \ - --postrm=rpm/postrm.sh \ - --prerm=rpm/prerm.sh \ - --description="Tailscale NGINX authentication protocol handler" \ - --files=./tailscale.nginx-auth:/usr/sbin/tailscale.nginx-auth,./tailscale.nginx-auth.socket:/lib/systemd/system/tailscale.nginx-auth.socket,./tailscale.nginx-auth.service:/lib/systemd/system/tailscale.nginx-auth.service,./README.md:/usr/share/tailscale/nginx-auth/README.md -done +#!/usr/bin/env bash + +set -e + +VERSION=0.1.3 +for ARCH in amd64 arm64; do + CGO_ENABLED=0 GOARCH=${ARCH} GOOS=linux go build -o tailscale.nginx-auth . + + mkpkg \ + --out=tailscale-nginx-auth-${VERSION}-${ARCH}.deb \ + --name=tailscale-nginx-auth \ + --version=${VERSION} \ + --type=deb \ + --arch=${ARCH} \ + --postinst=deb/postinst.sh \ + --postrm=deb/postrm.sh \ + --prerm=deb/prerm.sh \ + --description="Tailscale NGINX authentication protocol handler" \ + --files=./tailscale.nginx-auth:/usr/sbin/tailscale.nginx-auth,./tailscale.nginx-auth.socket:/lib/systemd/system/tailscale.nginx-auth.socket,./tailscale.nginx-auth.service:/lib/systemd/system/tailscale.nginx-auth.service,./README.md:/usr/share/tailscale/nginx-auth/README.md + + mkpkg \ + --out=tailscale-nginx-auth-${VERSION}-${ARCH}.rpm \ + --name=tailscale-nginx-auth \ + --version=${VERSION} \ + --type=rpm \ + --arch=${ARCH} \ + --postinst=rpm/postinst.sh \ + --postrm=rpm/postrm.sh \ + --prerm=rpm/prerm.sh \ + --description="Tailscale NGINX authentication protocol handler" \ + --files=./tailscale.nginx-auth:/usr/sbin/tailscale.nginx-auth,./tailscale.nginx-auth.socket:/lib/systemd/system/tailscale.nginx-auth.socket,./tailscale.nginx-auth.service:/lib/systemd/system/tailscale.nginx-auth.service,./README.md:/usr/share/tailscale/nginx-auth/README.md +done diff --git a/cmd/nginx-auth/nginx-auth.go b/cmd/nginx-auth/nginx-auth.go index befcb6d6c..09da74da1 100644 --- a/cmd/nginx-auth/nginx-auth.go +++ b/cmd/nginx-auth/nginx-auth.go @@ -1,128 +1,128 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux - -// Command nginx-auth is a tool that allows users to use Tailscale Whois -// authentication with NGINX as a reverse proxy. This allows users that -// already have a bunch of services hosted on an internal NGINX server -// to point those domains to the Tailscale IP of the NGINX server and -// then seamlessly use Tailscale for authentication. -package main - -import ( - "flag" - "log" - "net" - "net/http" - "net/netip" - "net/url" - "os" - "strings" - - "github.com/coreos/go-systemd/activation" - "tailscale.com/client/tailscale" -) - -var ( - sockPath = flag.String("sockpath", "", "the filesystem path for the unix socket this service exposes") -) - -func main() { - flag.Parse() - - mux := http.NewServeMux() - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - remoteHost := r.Header.Get("Remote-Addr") - remotePort := r.Header.Get("Remote-Port") - if remoteHost == "" || remotePort == "" { - w.WriteHeader(http.StatusBadRequest) - log.Println("set Remote-Addr to $remote_addr and Remote-Port to $remote_port in your nginx config") - return - } - - remoteAddrStr := net.JoinHostPort(remoteHost, remotePort) - remoteAddr, err := netip.ParseAddrPort(remoteAddrStr) - if err != nil { - w.WriteHeader(http.StatusUnauthorized) - log.Printf("remote address and port are not valid: %v", err) - return - } - - info, err := tailscale.WhoIs(r.Context(), remoteAddr.String()) - if err != nil { - w.WriteHeader(http.StatusUnauthorized) - log.Printf("can't look up %s: %v", remoteAddr, err) - return - } - - if info.Node.IsTagged() { - w.WriteHeader(http.StatusForbidden) - log.Printf("node %s is tagged", info.Node.Hostinfo.Hostname()) - return - } - - // tailnet of connected node. When accessing shared nodes, this - // will be empty because the tailnet of the sharee is not exposed. - var tailnet string - - if !info.Node.Hostinfo.ShareeNode() { - var ok bool - _, tailnet, ok = strings.Cut(info.Node.Name, info.Node.ComputedName+".") - if !ok { - w.WriteHeader(http.StatusUnauthorized) - log.Printf("can't extract tailnet name from hostname %q", info.Node.Name) - return - } - tailnet = strings.TrimSuffix(tailnet, ".beta.tailscale.net") - } - - if expectedTailnet := r.Header.Get("Expected-Tailnet"); expectedTailnet != "" && expectedTailnet != tailnet { - w.WriteHeader(http.StatusForbidden) - log.Printf("user is part of tailnet %s, wanted: %s", tailnet, url.QueryEscape(expectedTailnet)) - return - } - - h := w.Header() - h.Set("Tailscale-Login", strings.Split(info.UserProfile.LoginName, "@")[0]) - h.Set("Tailscale-User", info.UserProfile.LoginName) - h.Set("Tailscale-Name", info.UserProfile.DisplayName) - h.Set("Tailscale-Profile-Picture", info.UserProfile.ProfilePicURL) - h.Set("Tailscale-Tailnet", tailnet) - w.WriteHeader(http.StatusNoContent) - }) - - if *sockPath != "" { - _ = os.Remove(*sockPath) // ignore error, this file may not already exist - ln, err := net.Listen("unix", *sockPath) - if err != nil { - log.Fatalf("can't listen on %s: %v", *sockPath, err) - } - defer ln.Close() - - log.Printf("listening on %s", *sockPath) - log.Fatal(http.Serve(ln, mux)) - } - - listeners, err := activation.Listeners() - if err != nil { - log.Fatalf("no sockets passed to this service with systemd: %v", err) - } - - // NOTE(Xe): normally you'd want to make a waitgroup here and then register - // each listener with it. In this case I want this to blow up horribly if - // any of the listeners stop working. systemd will restart it due to the - // socket activation at play. - // - // TL;DR: Let it crash, it will come back - for _, ln := range listeners { - go func(ln net.Listener) { - log.Printf("listening on %s", ln.Addr()) - log.Fatal(http.Serve(ln, mux)) - }(ln) - } - - for { - select {} - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +// Command nginx-auth is a tool that allows users to use Tailscale Whois +// authentication with NGINX as a reverse proxy. This allows users that +// already have a bunch of services hosted on an internal NGINX server +// to point those domains to the Tailscale IP of the NGINX server and +// then seamlessly use Tailscale for authentication. +package main + +import ( + "flag" + "log" + "net" + "net/http" + "net/netip" + "net/url" + "os" + "strings" + + "github.com/coreos/go-systemd/activation" + "tailscale.com/client/tailscale" +) + +var ( + sockPath = flag.String("sockpath", "", "the filesystem path for the unix socket this service exposes") +) + +func main() { + flag.Parse() + + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + remoteHost := r.Header.Get("Remote-Addr") + remotePort := r.Header.Get("Remote-Port") + if remoteHost == "" || remotePort == "" { + w.WriteHeader(http.StatusBadRequest) + log.Println("set Remote-Addr to $remote_addr and Remote-Port to $remote_port in your nginx config") + return + } + + remoteAddrStr := net.JoinHostPort(remoteHost, remotePort) + remoteAddr, err := netip.ParseAddrPort(remoteAddrStr) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + log.Printf("remote address and port are not valid: %v", err) + return + } + + info, err := tailscale.WhoIs(r.Context(), remoteAddr.String()) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + log.Printf("can't look up %s: %v", remoteAddr, err) + return + } + + if info.Node.IsTagged() { + w.WriteHeader(http.StatusForbidden) + log.Printf("node %s is tagged", info.Node.Hostinfo.Hostname()) + return + } + + // tailnet of connected node. When accessing shared nodes, this + // will be empty because the tailnet of the sharee is not exposed. + var tailnet string + + if !info.Node.Hostinfo.ShareeNode() { + var ok bool + _, tailnet, ok = strings.Cut(info.Node.Name, info.Node.ComputedName+".") + if !ok { + w.WriteHeader(http.StatusUnauthorized) + log.Printf("can't extract tailnet name from hostname %q", info.Node.Name) + return + } + tailnet = strings.TrimSuffix(tailnet, ".beta.tailscale.net") + } + + if expectedTailnet := r.Header.Get("Expected-Tailnet"); expectedTailnet != "" && expectedTailnet != tailnet { + w.WriteHeader(http.StatusForbidden) + log.Printf("user is part of tailnet %s, wanted: %s", tailnet, url.QueryEscape(expectedTailnet)) + return + } + + h := w.Header() + h.Set("Tailscale-Login", strings.Split(info.UserProfile.LoginName, "@")[0]) + h.Set("Tailscale-User", info.UserProfile.LoginName) + h.Set("Tailscale-Name", info.UserProfile.DisplayName) + h.Set("Tailscale-Profile-Picture", info.UserProfile.ProfilePicURL) + h.Set("Tailscale-Tailnet", tailnet) + w.WriteHeader(http.StatusNoContent) + }) + + if *sockPath != "" { + _ = os.Remove(*sockPath) // ignore error, this file may not already exist + ln, err := net.Listen("unix", *sockPath) + if err != nil { + log.Fatalf("can't listen on %s: %v", *sockPath, err) + } + defer ln.Close() + + log.Printf("listening on %s", *sockPath) + log.Fatal(http.Serve(ln, mux)) + } + + listeners, err := activation.Listeners() + if err != nil { + log.Fatalf("no sockets passed to this service with systemd: %v", err) + } + + // NOTE(Xe): normally you'd want to make a waitgroup here and then register + // each listener with it. In this case I want this to blow up horribly if + // any of the listeners stop working. systemd will restart it due to the + // socket activation at play. + // + // TL;DR: Let it crash, it will come back + for _, ln := range listeners { + go func(ln net.Listener) { + log.Printf("listening on %s", ln.Addr()) + log.Fatal(http.Serve(ln, mux)) + }(ln) + } + + for { + select {} + } +} diff --git a/cmd/nginx-auth/rpm/postrm.sh b/cmd/nginx-auth/rpm/postrm.sh index d8d36893f..3d0abfb19 100755 --- a/cmd/nginx-auth/rpm/postrm.sh +++ b/cmd/nginx-auth/rpm/postrm.sh @@ -1,9 +1,9 @@ -# $1 == 0 for uninstallation. -# $1 == 1 for removing old package during upgrade. - -systemctl daemon-reload >/dev/null 2>&1 || : -if [ $1 -ge 1 ] ; then - # Package upgrade, not uninstall - systemctl stop tailscale.nginx-auth.service >/dev/null 2>&1 || : - systemctl try-restart tailscale.nginx-auth.socket >/dev/null 2>&1 || : -fi +# $1 == 0 for uninstallation. +# $1 == 1 for removing old package during upgrade. + +systemctl daemon-reload >/dev/null 2>&1 || : +if [ $1 -ge 1 ] ; then + # Package upgrade, not uninstall + systemctl stop tailscale.nginx-auth.service >/dev/null 2>&1 || : + systemctl try-restart tailscale.nginx-auth.socket >/dev/null 2>&1 || : +fi diff --git a/cmd/nginx-auth/rpm/prerm.sh b/cmd/nginx-auth/rpm/prerm.sh index 2e47a53ed..1f198d829 100755 --- a/cmd/nginx-auth/rpm/prerm.sh +++ b/cmd/nginx-auth/rpm/prerm.sh @@ -1,9 +1,9 @@ -# $1 == 0 for uninstallation. -# $1 == 1 for removing old package during upgrade. - -if [ $1 -eq 0 ] ; then - # Package removal, not upgrade - systemctl --no-reload disable tailscale.nginx-auth.socket > /dev/null 2>&1 || : - systemctl stop tailscale.nginx-auth.socket > /dev/null 2>&1 || : - systemctl stop tailscale.nginx-auth.service > /dev/null 2>&1 || : -fi +# $1 == 0 for uninstallation. +# $1 == 1 for removing old package during upgrade. + +if [ $1 -eq 0 ] ; then + # Package removal, not upgrade + systemctl --no-reload disable tailscale.nginx-auth.socket > /dev/null 2>&1 || : + systemctl stop tailscale.nginx-auth.socket > /dev/null 2>&1 || : + systemctl stop tailscale.nginx-auth.service > /dev/null 2>&1 || : +fi diff --git a/cmd/nginx-auth/tailscale.nginx-auth.service b/cmd/nginx-auth/tailscale.nginx-auth.service index 8534e25c1..086f6c774 100644 --- a/cmd/nginx-auth/tailscale.nginx-auth.service +++ b/cmd/nginx-auth/tailscale.nginx-auth.service @@ -1,11 +1,11 @@ -[Unit] -Description=Tailscale NGINX Authentication service -After=nginx.service -Wants=nginx.service - -[Service] -ExecStart=/usr/sbin/tailscale.nginx-auth -DynamicUser=yes - -[Install] -WantedBy=default.target +[Unit] +Description=Tailscale NGINX Authentication service +After=nginx.service +Wants=nginx.service + +[Service] +ExecStart=/usr/sbin/tailscale.nginx-auth +DynamicUser=yes + +[Install] +WantedBy=default.target diff --git a/cmd/nginx-auth/tailscale.nginx-auth.socket b/cmd/nginx-auth/tailscale.nginx-auth.socket index 53e3e8d83..7e5641ff3 100644 --- a/cmd/nginx-auth/tailscale.nginx-auth.socket +++ b/cmd/nginx-auth/tailscale.nginx-auth.socket @@ -1,9 +1,9 @@ -[Unit] -Description=Tailscale NGINX Authentication socket -PartOf=tailscale.nginx-auth.service - -[Socket] -ListenStream=/var/run/tailscale.nginx-auth.sock - -[Install] +[Unit] +Description=Tailscale NGINX Authentication socket +PartOf=tailscale.nginx-auth.service + +[Socket] +ListenStream=/var/run/tailscale.nginx-auth.sock + +[Install] WantedBy=sockets.target \ No newline at end of file diff --git a/cmd/pgproxy/README.md b/cmd/pgproxy/README.md index a867ad8ca..2e013072a 100644 --- a/cmd/pgproxy/README.md +++ b/cmd/pgproxy/README.md @@ -1,42 +1,42 @@ -# pgproxy - -The pgproxy server is a proxy for the Postgres wire protocol. [Read -more in our blog -post](https://tailscale.com/blog/introducing-pgproxy/) about it! - -The proxy runs an in-process Tailscale instance, accepts postgres -client connections over Tailscale only, and proxies them to the -configured upstream postgres server. - -This proxy exists because postgres clients default to very insecure -connection settings: either they "prefer" but do not require TLS; or -they set sslmode=require, which merely requires that a TLS handshake -took place, but don't verify the server's TLS certificate or the -presented TLS hostname. In other words, sslmode=require enforces that -a TLS session is created, but that session can trivially be -machine-in-the-middled to steal credentials, data, inject malicious -queries, and so forth. - -Because this flaw is in the client's validation of the TLS session, -you have no way of reliably detecting the misconfiguration -server-side. You could fix the configuration of all the clients you -know of, but the default makes it very easy to accidentally regress. - -Instead of trying to verify client configuration over time, this proxy -removes the need for postgres clients to be configured correctly: the -upstream database is configured to only accept connections from the -proxy, and the proxy is only available to clients over Tailscale. - -Therefore, clients must use the proxy to connect to the database. The -client<>proxy connection is secured end-to-end by Tailscale, which the -proxy enforces by verifying that the connecting client is a known -current Tailscale peer. The proxy<>server connection is established by -the proxy itself, using strict TLS verification settings, and the -client is only allowed to communicate with the server once we've -established that the upstream connection is safe to use. - -A couple side benefits: because clients can only connect via -Tailscale, you can use Tailscale ACLs as an extra layer of defense on -top of the postgres user/password authentication. And, the proxy can -maintain an audit log of who connected to the database, complete with -the strongly authenticated Tailscale identity of the client. +# pgproxy + +The pgproxy server is a proxy for the Postgres wire protocol. [Read +more in our blog +post](https://tailscale.com/blog/introducing-pgproxy/) about it! + +The proxy runs an in-process Tailscale instance, accepts postgres +client connections over Tailscale only, and proxies them to the +configured upstream postgres server. + +This proxy exists because postgres clients default to very insecure +connection settings: either they "prefer" but do not require TLS; or +they set sslmode=require, which merely requires that a TLS handshake +took place, but don't verify the server's TLS certificate or the +presented TLS hostname. In other words, sslmode=require enforces that +a TLS session is created, but that session can trivially be +machine-in-the-middled to steal credentials, data, inject malicious +queries, and so forth. + +Because this flaw is in the client's validation of the TLS session, +you have no way of reliably detecting the misconfiguration +server-side. You could fix the configuration of all the clients you +know of, but the default makes it very easy to accidentally regress. + +Instead of trying to verify client configuration over time, this proxy +removes the need for postgres clients to be configured correctly: the +upstream database is configured to only accept connections from the +proxy, and the proxy is only available to clients over Tailscale. + +Therefore, clients must use the proxy to connect to the database. The +client<>proxy connection is secured end-to-end by Tailscale, which the +proxy enforces by verifying that the connecting client is a known +current Tailscale peer. The proxy<>server connection is established by +the proxy itself, using strict TLS verification settings, and the +client is only allowed to communicate with the server once we've +established that the upstream connection is safe to use. + +A couple side benefits: because clients can only connect via +Tailscale, you can use Tailscale ACLs as an extra layer of defense on +top of the postgres user/password authentication. And, the proxy can +maintain an audit log of who connected to the database, complete with +the strongly authenticated Tailscale identity of the client. diff --git a/cmd/printdep/printdep.go b/cmd/printdep/printdep.go index 0790a8b81..044283209 100644 --- a/cmd/printdep/printdep.go +++ b/cmd/printdep/printdep.go @@ -1,41 +1,41 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The printdep command is a build system tool for printing out information -// about dependencies. -package main - -import ( - "flag" - "fmt" - "log" - "runtime" - "strings" - - ts "tailscale.com" -) - -var ( - goToolchain = flag.Bool("go", false, "print the supported Go toolchain git hash (a github.com/tailscale/go commit)") - goToolchainURL = flag.Bool("go-url", false, "print the URL to the tarball of the Tailscale Go toolchain") - alpine = flag.Bool("alpine", false, "print the tag of alpine docker image") -) - -func main() { - flag.Parse() - if *alpine { - fmt.Println(strings.TrimSpace(ts.AlpineDockerTag)) - return - } - if *goToolchain { - fmt.Println(strings.TrimSpace(ts.GoToolchainRev)) - } - if *goToolchainURL { - switch runtime.GOOS { - case "linux", "darwin": - default: - log.Fatalf("unsupported GOOS %q", runtime.GOOS) - } - fmt.Printf("https://github.com/tailscale/go/releases/download/build-%s/%s-%s.tar.gz\n", strings.TrimSpace(ts.GoToolchainRev), runtime.GOOS, runtime.GOARCH) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The printdep command is a build system tool for printing out information +// about dependencies. +package main + +import ( + "flag" + "fmt" + "log" + "runtime" + "strings" + + ts "tailscale.com" +) + +var ( + goToolchain = flag.Bool("go", false, "print the supported Go toolchain git hash (a github.com/tailscale/go commit)") + goToolchainURL = flag.Bool("go-url", false, "print the URL to the tarball of the Tailscale Go toolchain") + alpine = flag.Bool("alpine", false, "print the tag of alpine docker image") +) + +func main() { + flag.Parse() + if *alpine { + fmt.Println(strings.TrimSpace(ts.AlpineDockerTag)) + return + } + if *goToolchain { + fmt.Println(strings.TrimSpace(ts.GoToolchainRev)) + } + if *goToolchainURL { + switch runtime.GOOS { + case "linux", "darwin": + default: + log.Fatalf("unsupported GOOS %q", runtime.GOOS) + } + fmt.Printf("https://github.com/tailscale/go/releases/download/build-%s/%s-%s.tar.gz\n", strings.TrimSpace(ts.GoToolchainRev), runtime.GOOS, runtime.GOARCH) + } +} diff --git a/cmd/sniproxy/.gitignore b/cmd/sniproxy/.gitignore index 0bca33912..b1399c881 100644 --- a/cmd/sniproxy/.gitignore +++ b/cmd/sniproxy/.gitignore @@ -1 +1 @@ -sniproxy +sniproxy diff --git a/cmd/sniproxy/handlers_test.go b/cmd/sniproxy/handlers_test.go index 8ec5b097c..4f9fc6a34 100644 --- a/cmd/sniproxy/handlers_test.go +++ b/cmd/sniproxy/handlers_test.go @@ -1,159 +1,159 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -import ( - "bytes" - "context" - "encoding/hex" - "io" - "net" - "net/netip" - "strings" - "testing" - - "tailscale.com/net/memnet" -) - -func echoConnOnce(conn net.Conn) { - defer conn.Close() - - b := make([]byte, 256) - n, err := conn.Read(b) - if err != nil { - return - } - - if _, err := conn.Write(b[:n]); err != nil { - return - } -} - -func TestTCPRoundRobinHandler(t *testing.T) { - h := tcpRoundRobinHandler{ - To: []string{"yeet.com"}, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - if network != "tcp" { - t.Errorf("network = %s, want %s", network, "tcp") - } - if addr != "yeet.com:22" { - t.Errorf("addr = %s, want %s", addr, "yeet.com:22") - } - - c, s := memnet.NewConn("outbound", 1024) - go echoConnOnce(s) - return c, nil - }, - } - - cSock, sSock := memnet.NewTCPConn(netip.MustParseAddrPort("10.64.1.2:22"), netip.MustParseAddrPort("10.64.1.2:22"), 1024) - h.Handle(sSock) - - // Test data write and read, the other end will echo back - // a single stanza - want := "hello" - if _, err := io.WriteString(cSock, want); err != nil { - t.Fatal(err) - } - got := make([]byte, len(want)) - if _, err := io.ReadAtLeast(cSock, got, len(got)); err != nil { - t.Fatal(err) - } - if string(got) != want { - t.Errorf("got %q, want %q", got, want) - } - - // The other end closed the socket after the first echo, so - // any following read should error. - io.WriteString(cSock, "deadass heres some data on god fr") - if _, err := io.ReadAtLeast(cSock, got, len(got)); err == nil { - t.Error("read succeeded on closed socket") - } -} - -// Capture of first TCP data segment for a connection to https://pkgs.tailscale.com -const tlsStart = `45000239ff1840004006f9f5c0a801f2 -c726b5efcf9e01bbe803b21394e3b752 -801801f641dc00000101080ade3474f2 -2fb93ee71603010200010001fc030303 -c3acbd19d2624765bb19af4bce03365e -1d197f5bb939cdadeff26b0f8e7a0620 -295b04127b82bae46aac4ff58cffef25 -eba75a4b7a6de729532c411bd9dd0d2c -00203a3a130113021303c02bc02fc02c -c030cca9cca8c013c014009c009d002f -003501000193caca0000000a000a0008 -1a1a001d001700180010000e000c0268 -3208687474702f312e31002b0007062a -2a03040303ff01000100000d00120010 -04030804040105030805050108060601 -000b00020100002300000033002b0029 -1a1a000100001d0020d3c76bef062979 -a812ce935cfb4dbe6b3a84dc5ba9226f -23b0f34af9d1d03b4a001b0003020002 -00120000446900050003026832000000 -170015000012706b67732e7461696c73 -63616c652e636f6d002d000201010005 -00050100000000001700003a3a000100 -0015002d000000000000000000000000 -00000000000000000000000000000000 -00000000000000000000000000000000 -0000290094006f0069e76f2016f963ad -38c8632d1f240cd75e00e25fdef295d4 -7042b26f3a9a543b1c7dc74939d77803 -20527d423ff996997bda2c6383a14f49 -219eeef8a053e90a32228df37ddbe126 -eccf6b085c93890d08341d819aea6111 -0d909f4cd6b071d9ea40618e74588a33 -90d494bbb5c3002120d5a164a16c9724 -c9ef5e540d8d6f007789a7acf9f5f16f -bf6a1907a6782ed02b` - -func fakeSNIHeader() []byte { - b, err := hex.DecodeString(strings.Replace(tlsStart, "\n", "", -1)) - if err != nil { - panic(err) - } - return b[0x34:] // trim IP + TCP header -} - -func TestTCPSNIHandler(t *testing.T) { - h := tcpSNIHandler{ - Allowlist: []string{"pkgs.tailscale.com"}, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - if network != "tcp" { - t.Errorf("network = %s, want %s", network, "tcp") - } - if addr != "pkgs.tailscale.com:443" { - t.Errorf("addr = %s, want %s", addr, "pkgs.tailscale.com:443") - } - - c, s := memnet.NewConn("outbound", 1024) - go echoConnOnce(s) - return c, nil - }, - } - - cSock, sSock := memnet.NewTCPConn(netip.MustParseAddrPort("10.64.1.2:22"), netip.MustParseAddrPort("10.64.1.2:443"), 1024) - h.Handle(sSock) - - // Fake a TLS handshake record with an SNI in it. - if _, err := cSock.Write(fakeSNIHeader()); err != nil { - t.Fatal(err) - } - - // Test read, the other end will echo back - // a single stanza, which is at least the beginning of the SNI header. - want := fakeSNIHeader()[:5] - if _, err := cSock.Write(want); err != nil { - t.Fatal(err) - } - got := make([]byte, len(want)) - if _, err := io.ReadAtLeast(cSock, got, len(got)); err != nil { - t.Fatal(err) - } - if !bytes.Equal(got, want) { - t.Errorf("got %q, want %q", got, want) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "bytes" + "context" + "encoding/hex" + "io" + "net" + "net/netip" + "strings" + "testing" + + "tailscale.com/net/memnet" +) + +func echoConnOnce(conn net.Conn) { + defer conn.Close() + + b := make([]byte, 256) + n, err := conn.Read(b) + if err != nil { + return + } + + if _, err := conn.Write(b[:n]); err != nil { + return + } +} + +func TestTCPRoundRobinHandler(t *testing.T) { + h := tcpRoundRobinHandler{ + To: []string{"yeet.com"}, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + if network != "tcp" { + t.Errorf("network = %s, want %s", network, "tcp") + } + if addr != "yeet.com:22" { + t.Errorf("addr = %s, want %s", addr, "yeet.com:22") + } + + c, s := memnet.NewConn("outbound", 1024) + go echoConnOnce(s) + return c, nil + }, + } + + cSock, sSock := memnet.NewTCPConn(netip.MustParseAddrPort("10.64.1.2:22"), netip.MustParseAddrPort("10.64.1.2:22"), 1024) + h.Handle(sSock) + + // Test data write and read, the other end will echo back + // a single stanza + want := "hello" + if _, err := io.WriteString(cSock, want); err != nil { + t.Fatal(err) + } + got := make([]byte, len(want)) + if _, err := io.ReadAtLeast(cSock, got, len(got)); err != nil { + t.Fatal(err) + } + if string(got) != want { + t.Errorf("got %q, want %q", got, want) + } + + // The other end closed the socket after the first echo, so + // any following read should error. + io.WriteString(cSock, "deadass heres some data on god fr") + if _, err := io.ReadAtLeast(cSock, got, len(got)); err == nil { + t.Error("read succeeded on closed socket") + } +} + +// Capture of first TCP data segment for a connection to https://pkgs.tailscale.com +const tlsStart = `45000239ff1840004006f9f5c0a801f2 +c726b5efcf9e01bbe803b21394e3b752 +801801f641dc00000101080ade3474f2 +2fb93ee71603010200010001fc030303 +c3acbd19d2624765bb19af4bce03365e +1d197f5bb939cdadeff26b0f8e7a0620 +295b04127b82bae46aac4ff58cffef25 +eba75a4b7a6de729532c411bd9dd0d2c +00203a3a130113021303c02bc02fc02c +c030cca9cca8c013c014009c009d002f +003501000193caca0000000a000a0008 +1a1a001d001700180010000e000c0268 +3208687474702f312e31002b0007062a +2a03040303ff01000100000d00120010 +04030804040105030805050108060601 +000b00020100002300000033002b0029 +1a1a000100001d0020d3c76bef062979 +a812ce935cfb4dbe6b3a84dc5ba9226f +23b0f34af9d1d03b4a001b0003020002 +00120000446900050003026832000000 +170015000012706b67732e7461696c73 +63616c652e636f6d002d000201010005 +00050100000000001700003a3a000100 +0015002d000000000000000000000000 +00000000000000000000000000000000 +00000000000000000000000000000000 +0000290094006f0069e76f2016f963ad +38c8632d1f240cd75e00e25fdef295d4 +7042b26f3a9a543b1c7dc74939d77803 +20527d423ff996997bda2c6383a14f49 +219eeef8a053e90a32228df37ddbe126 +eccf6b085c93890d08341d819aea6111 +0d909f4cd6b071d9ea40618e74588a33 +90d494bbb5c3002120d5a164a16c9724 +c9ef5e540d8d6f007789a7acf9f5f16f +bf6a1907a6782ed02b` + +func fakeSNIHeader() []byte { + b, err := hex.DecodeString(strings.Replace(tlsStart, "\n", "", -1)) + if err != nil { + panic(err) + } + return b[0x34:] // trim IP + TCP header +} + +func TestTCPSNIHandler(t *testing.T) { + h := tcpSNIHandler{ + Allowlist: []string{"pkgs.tailscale.com"}, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + if network != "tcp" { + t.Errorf("network = %s, want %s", network, "tcp") + } + if addr != "pkgs.tailscale.com:443" { + t.Errorf("addr = %s, want %s", addr, "pkgs.tailscale.com:443") + } + + c, s := memnet.NewConn("outbound", 1024) + go echoConnOnce(s) + return c, nil + }, + } + + cSock, sSock := memnet.NewTCPConn(netip.MustParseAddrPort("10.64.1.2:22"), netip.MustParseAddrPort("10.64.1.2:443"), 1024) + h.Handle(sSock) + + // Fake a TLS handshake record with an SNI in it. + if _, err := cSock.Write(fakeSNIHeader()); err != nil { + t.Fatal(err) + } + + // Test read, the other end will echo back + // a single stanza, which is at least the beginning of the SNI header. + want := fakeSNIHeader()[:5] + if _, err := cSock.Write(want); err != nil { + t.Fatal(err) + } + got := make([]byte, len(want)) + if _, err := io.ReadAtLeast(cSock, got, len(got)); err != nil { + t.Fatal(err) + } + if !bytes.Equal(got, want) { + t.Errorf("got %q, want %q", got, want) + } +} diff --git a/cmd/sniproxy/server.go b/cmd/sniproxy/server.go index c89420661..b322b6f4b 100644 --- a/cmd/sniproxy/server.go +++ b/cmd/sniproxy/server.go @@ -1,327 +1,327 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -import ( - "expvar" - "log" - "net" - "net/netip" - "sync" - "time" - - "golang.org/x/net/dns/dnsmessage" - "tailscale.com/metrics" - "tailscale.com/tailcfg" - "tailscale.com/types/appctype" - "tailscale.com/types/ipproto" - "tailscale.com/types/nettype" - "tailscale.com/util/clientmetric" - "tailscale.com/util/mak" -) - -var tsMBox = dnsmessage.MustNewName("support.tailscale.com.") - -// target describes the predicates which route some inbound -// traffic to the app connector to a specific handler. -type target struct { - Dest netip.Prefix - Matching tailcfg.ProtoPortRange -} - -// Server implements an App Connector as expressed in sniproxy. -type Server struct { - mu sync.RWMutex // mu guards following fields - connectors map[appctype.ConfigID]connector -} - -type appcMetrics struct { - dnsResponses expvar.Int - dnsFailures expvar.Int - tcpConns expvar.Int - sniConns expvar.Int - unhandledConns expvar.Int -} - -var getMetrics = sync.OnceValue[*appcMetrics](func() *appcMetrics { - m := appcMetrics{} - - stats := new(metrics.Set) - stats.Set("tls_sessions", &m.sniConns) - clientmetric.NewCounterFunc("sniproxy_tls_sessions", m.sniConns.Value) - stats.Set("tcp_sessions", &m.tcpConns) - clientmetric.NewCounterFunc("sniproxy_tcp_sessions", m.tcpConns.Value) - stats.Set("dns_responses", &m.dnsResponses) - clientmetric.NewCounterFunc("sniproxy_dns_responses", m.dnsResponses.Value) - stats.Set("dns_failed", &m.dnsFailures) - clientmetric.NewCounterFunc("sniproxy_dns_failed", m.dnsFailures.Value) - expvar.Publish("sniproxy", stats) - - return &m -}) - -// Configure applies the provided configuration to the app connector. -func (s *Server) Configure(cfg *appctype.AppConnectorConfig) { - s.mu.Lock() - defer s.mu.Unlock() - s.connectors = makeConnectorsFromConfig(cfg) - log.Printf("installed app connector config: %+v", s.connectors) -} - -// HandleTCPFlow implements tsnet.FallbackTCPHandler. -func (s *Server) HandleTCPFlow(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) { - m := getMetrics() - s.mu.RLock() - defer s.mu.RUnlock() - - for _, c := range s.connectors { - if handler, intercept := c.handleTCPFlow(src, dst, m); intercept { - return handler, intercept - } - } - - return nil, false -} - -// HandleDNS handles a DNS request to the app connector. -func (s *Server) HandleDNS(c nettype.ConnPacketConn) { - defer c.Close() - c.SetReadDeadline(time.Now().Add(5 * time.Second)) - m := getMetrics() - - buf := make([]byte, 1500) - n, err := c.Read(buf) - if err != nil { - log.Printf("HandleDNS: read failed: %v\n ", err) - m.dnsFailures.Add(1) - return - } - - addrPortStr := c.LocalAddr().String() - host, _, err := net.SplitHostPort(addrPortStr) - if err != nil { - log.Printf("HandleDNS: bogus addrPort %q", addrPortStr) - m.dnsFailures.Add(1) - return - } - localAddr, err := netip.ParseAddr(host) - if err != nil { - log.Printf("HandleDNS: bogus local address %q", host) - m.dnsFailures.Add(1) - return - } - - var msg dnsmessage.Message - err = msg.Unpack(buf[:n]) - if err != nil { - log.Printf("HandleDNS: dnsmessage unpack failed: %v\n ", err) - m.dnsFailures.Add(1) - return - } - - s.mu.RLock() - defer s.mu.RUnlock() - for _, connector := range s.connectors { - resp, err := connector.handleDNS(&msg, localAddr) - if err != nil { - log.Printf("HandleDNS: connector handling failed: %v\n", err) - m.dnsFailures.Add(1) - return - } - if len(resp) > 0 { - // This connector handled the DNS request - _, err = c.Write(resp) - if err != nil { - log.Printf("HandleDNS: write failed: %v\n", err) - m.dnsFailures.Add(1) - return - } - - m.dnsResponses.Add(1) - return - } - } -} - -// connector describes a logical collection of -// services which need to be proxied. -type connector struct { - Handlers map[target]handler -} - -// handleTCPFlow implements tsnet.FallbackTCPHandler. -func (c *connector) handleTCPFlow(src, dst netip.AddrPort, m *appcMetrics) (handler func(net.Conn), intercept bool) { - for t, h := range c.Handlers { - if t.Matching.Proto != 0 && t.Matching.Proto != int(ipproto.TCP) { - continue - } - if !t.Dest.Contains(dst.Addr()) { - continue - } - if !t.Matching.Ports.Contains(dst.Port()) { - continue - } - - switch h.(type) { - case *tcpSNIHandler: - m.sniConns.Add(1) - case *tcpRoundRobinHandler: - m.tcpConns.Add(1) - default: - log.Printf("handleTCPFlow: unhandled handler type %T", h) - } - - return h.Handle, true - } - - m.unhandledConns.Add(1) - return nil, false -} - -// handleDNS returns the DNS response to the given query. If this -// connector is unable to handle the request, nil is returned. -func (c *connector) handleDNS(req *dnsmessage.Message, localAddr netip.Addr) (response []byte, err error) { - for t, h := range c.Handlers { - if t.Dest.Contains(localAddr) { - return makeDNSResponse(req, h.ReachableOn()) - } - } - - // Did not match, signal 'not handled' to caller - return nil, nil -} - -func makeDNSResponse(req *dnsmessage.Message, reachableIPs []netip.Addr) (response []byte, err error) { - resp := dnsmessage.NewBuilder(response, - dnsmessage.Header{ - ID: req.Header.ID, - Response: true, - Authoritative: true, - }) - resp.EnableCompression() - - if len(req.Questions) == 0 { - response, _ = resp.Finish() - return response, nil - } - q := req.Questions[0] - err = resp.StartQuestions() - if err != nil { - return - } - resp.Question(q) - - err = resp.StartAnswers() - if err != nil { - return - } - - switch q.Type { - case dnsmessage.TypeAAAA: - for _, ip := range reachableIPs { - if ip.Is6() { - err = resp.AAAAResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, - dnsmessage.AAAAResource{AAAA: ip.As16()}, - ) - } - } - - case dnsmessage.TypeA: - for _, ip := range reachableIPs { - if ip.Is4() { - err = resp.AResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, - dnsmessage.AResource{A: ip.As4()}, - ) - } - } - - case dnsmessage.TypeSOA: - err = resp.SOAResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, - dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600, - Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60}, - ) - case dnsmessage.TypeNS: - err = resp.NSResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, - dnsmessage.NSResource{NS: tsMBox}, - ) - } - - if err != nil { - return nil, err - } - return resp.Finish() -} - -type handler interface { - // Handle handles the given socket. - Handle(c net.Conn) - - // ReachableOn returns the IP addresses this handler is reachable on. - ReachableOn() []netip.Addr -} - -func installDNATHandler(d *appctype.DNATConfig, out *connector) { - // These handlers don't actually do DNAT, they just - // proxy the data over the connection. - var dialer net.Dialer - dialer.Timeout = 5 * time.Second - h := tcpRoundRobinHandler{ - To: d.To, - DialContext: dialer.DialContext, - ReachableIPs: d.Addrs, - } - - for _, addr := range d.Addrs { - for _, protoPort := range d.IP { - t := target{ - Dest: netip.PrefixFrom(addr, addr.BitLen()), - Matching: protoPort, - } - - mak.Set(&out.Handlers, t, handler(&h)) - } - } -} - -func installSNIHandler(c *appctype.SNIProxyConfig, out *connector) { - var dialer net.Dialer - dialer.Timeout = 5 * time.Second - h := tcpSNIHandler{ - Allowlist: c.AllowedDomains, - DialContext: dialer.DialContext, - ReachableIPs: c.Addrs, - } - - for _, addr := range c.Addrs { - for _, protoPort := range c.IP { - t := target{ - Dest: netip.PrefixFrom(addr, addr.BitLen()), - Matching: protoPort, - } - - mak.Set(&out.Handlers, t, handler(&h)) - } - } -} - -func makeConnectorsFromConfig(cfg *appctype.AppConnectorConfig) map[appctype.ConfigID]connector { - var connectors map[appctype.ConfigID]connector - - for cID, d := range cfg.DNAT { - c := connectors[cID] - installDNATHandler(&d, &c) - mak.Set(&connectors, cID, c) - } - for cID, d := range cfg.SNIProxy { - c := connectors[cID] - installSNIHandler(&d, &c) - mak.Set(&connectors, cID, c) - } - - return connectors -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "expvar" + "log" + "net" + "net/netip" + "sync" + "time" + + "golang.org/x/net/dns/dnsmessage" + "tailscale.com/metrics" + "tailscale.com/tailcfg" + "tailscale.com/types/appctype" + "tailscale.com/types/ipproto" + "tailscale.com/types/nettype" + "tailscale.com/util/clientmetric" + "tailscale.com/util/mak" +) + +var tsMBox = dnsmessage.MustNewName("support.tailscale.com.") + +// target describes the predicates which route some inbound +// traffic to the app connector to a specific handler. +type target struct { + Dest netip.Prefix + Matching tailcfg.ProtoPortRange +} + +// Server implements an App Connector as expressed in sniproxy. +type Server struct { + mu sync.RWMutex // mu guards following fields + connectors map[appctype.ConfigID]connector +} + +type appcMetrics struct { + dnsResponses expvar.Int + dnsFailures expvar.Int + tcpConns expvar.Int + sniConns expvar.Int + unhandledConns expvar.Int +} + +var getMetrics = sync.OnceValue[*appcMetrics](func() *appcMetrics { + m := appcMetrics{} + + stats := new(metrics.Set) + stats.Set("tls_sessions", &m.sniConns) + clientmetric.NewCounterFunc("sniproxy_tls_sessions", m.sniConns.Value) + stats.Set("tcp_sessions", &m.tcpConns) + clientmetric.NewCounterFunc("sniproxy_tcp_sessions", m.tcpConns.Value) + stats.Set("dns_responses", &m.dnsResponses) + clientmetric.NewCounterFunc("sniproxy_dns_responses", m.dnsResponses.Value) + stats.Set("dns_failed", &m.dnsFailures) + clientmetric.NewCounterFunc("sniproxy_dns_failed", m.dnsFailures.Value) + expvar.Publish("sniproxy", stats) + + return &m +}) + +// Configure applies the provided configuration to the app connector. +func (s *Server) Configure(cfg *appctype.AppConnectorConfig) { + s.mu.Lock() + defer s.mu.Unlock() + s.connectors = makeConnectorsFromConfig(cfg) + log.Printf("installed app connector config: %+v", s.connectors) +} + +// HandleTCPFlow implements tsnet.FallbackTCPHandler. +func (s *Server) HandleTCPFlow(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) { + m := getMetrics() + s.mu.RLock() + defer s.mu.RUnlock() + + for _, c := range s.connectors { + if handler, intercept := c.handleTCPFlow(src, dst, m); intercept { + return handler, intercept + } + } + + return nil, false +} + +// HandleDNS handles a DNS request to the app connector. +func (s *Server) HandleDNS(c nettype.ConnPacketConn) { + defer c.Close() + c.SetReadDeadline(time.Now().Add(5 * time.Second)) + m := getMetrics() + + buf := make([]byte, 1500) + n, err := c.Read(buf) + if err != nil { + log.Printf("HandleDNS: read failed: %v\n ", err) + m.dnsFailures.Add(1) + return + } + + addrPortStr := c.LocalAddr().String() + host, _, err := net.SplitHostPort(addrPortStr) + if err != nil { + log.Printf("HandleDNS: bogus addrPort %q", addrPortStr) + m.dnsFailures.Add(1) + return + } + localAddr, err := netip.ParseAddr(host) + if err != nil { + log.Printf("HandleDNS: bogus local address %q", host) + m.dnsFailures.Add(1) + return + } + + var msg dnsmessage.Message + err = msg.Unpack(buf[:n]) + if err != nil { + log.Printf("HandleDNS: dnsmessage unpack failed: %v\n ", err) + m.dnsFailures.Add(1) + return + } + + s.mu.RLock() + defer s.mu.RUnlock() + for _, connector := range s.connectors { + resp, err := connector.handleDNS(&msg, localAddr) + if err != nil { + log.Printf("HandleDNS: connector handling failed: %v\n", err) + m.dnsFailures.Add(1) + return + } + if len(resp) > 0 { + // This connector handled the DNS request + _, err = c.Write(resp) + if err != nil { + log.Printf("HandleDNS: write failed: %v\n", err) + m.dnsFailures.Add(1) + return + } + + m.dnsResponses.Add(1) + return + } + } +} + +// connector describes a logical collection of +// services which need to be proxied. +type connector struct { + Handlers map[target]handler +} + +// handleTCPFlow implements tsnet.FallbackTCPHandler. +func (c *connector) handleTCPFlow(src, dst netip.AddrPort, m *appcMetrics) (handler func(net.Conn), intercept bool) { + for t, h := range c.Handlers { + if t.Matching.Proto != 0 && t.Matching.Proto != int(ipproto.TCP) { + continue + } + if !t.Dest.Contains(dst.Addr()) { + continue + } + if !t.Matching.Ports.Contains(dst.Port()) { + continue + } + + switch h.(type) { + case *tcpSNIHandler: + m.sniConns.Add(1) + case *tcpRoundRobinHandler: + m.tcpConns.Add(1) + default: + log.Printf("handleTCPFlow: unhandled handler type %T", h) + } + + return h.Handle, true + } + + m.unhandledConns.Add(1) + return nil, false +} + +// handleDNS returns the DNS response to the given query. If this +// connector is unable to handle the request, nil is returned. +func (c *connector) handleDNS(req *dnsmessage.Message, localAddr netip.Addr) (response []byte, err error) { + for t, h := range c.Handlers { + if t.Dest.Contains(localAddr) { + return makeDNSResponse(req, h.ReachableOn()) + } + } + + // Did not match, signal 'not handled' to caller + return nil, nil +} + +func makeDNSResponse(req *dnsmessage.Message, reachableIPs []netip.Addr) (response []byte, err error) { + resp := dnsmessage.NewBuilder(response, + dnsmessage.Header{ + ID: req.Header.ID, + Response: true, + Authoritative: true, + }) + resp.EnableCompression() + + if len(req.Questions) == 0 { + response, _ = resp.Finish() + return response, nil + } + q := req.Questions[0] + err = resp.StartQuestions() + if err != nil { + return + } + resp.Question(q) + + err = resp.StartAnswers() + if err != nil { + return + } + + switch q.Type { + case dnsmessage.TypeAAAA: + for _, ip := range reachableIPs { + if ip.Is6() { + err = resp.AAAAResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.AAAAResource{AAAA: ip.As16()}, + ) + } + } + + case dnsmessage.TypeA: + for _, ip := range reachableIPs { + if ip.Is4() { + err = resp.AResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.AResource{A: ip.As4()}, + ) + } + } + + case dnsmessage.TypeSOA: + err = resp.SOAResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600, + Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60}, + ) + case dnsmessage.TypeNS: + err = resp.NSResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.NSResource{NS: tsMBox}, + ) + } + + if err != nil { + return nil, err + } + return resp.Finish() +} + +type handler interface { + // Handle handles the given socket. + Handle(c net.Conn) + + // ReachableOn returns the IP addresses this handler is reachable on. + ReachableOn() []netip.Addr +} + +func installDNATHandler(d *appctype.DNATConfig, out *connector) { + // These handlers don't actually do DNAT, they just + // proxy the data over the connection. + var dialer net.Dialer + dialer.Timeout = 5 * time.Second + h := tcpRoundRobinHandler{ + To: d.To, + DialContext: dialer.DialContext, + ReachableIPs: d.Addrs, + } + + for _, addr := range d.Addrs { + for _, protoPort := range d.IP { + t := target{ + Dest: netip.PrefixFrom(addr, addr.BitLen()), + Matching: protoPort, + } + + mak.Set(&out.Handlers, t, handler(&h)) + } + } +} + +func installSNIHandler(c *appctype.SNIProxyConfig, out *connector) { + var dialer net.Dialer + dialer.Timeout = 5 * time.Second + h := tcpSNIHandler{ + Allowlist: c.AllowedDomains, + DialContext: dialer.DialContext, + ReachableIPs: c.Addrs, + } + + for _, addr := range c.Addrs { + for _, protoPort := range c.IP { + t := target{ + Dest: netip.PrefixFrom(addr, addr.BitLen()), + Matching: protoPort, + } + + mak.Set(&out.Handlers, t, handler(&h)) + } + } +} + +func makeConnectorsFromConfig(cfg *appctype.AppConnectorConfig) map[appctype.ConfigID]connector { + var connectors map[appctype.ConfigID]connector + + for cID, d := range cfg.DNAT { + c := connectors[cID] + installDNATHandler(&d, &c) + mak.Set(&connectors, cID, c) + } + for cID, d := range cfg.SNIProxy { + c := connectors[cID] + installSNIHandler(&d, &c) + mak.Set(&connectors, cID, c) + } + + return connectors +} diff --git a/cmd/sniproxy/server_test.go b/cmd/sniproxy/server_test.go index 2a51c874c..d56f2aa75 100644 --- a/cmd/sniproxy/server_test.go +++ b/cmd/sniproxy/server_test.go @@ -1,95 +1,95 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -import ( - "net/netip" - "testing" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "tailscale.com/tailcfg" - "tailscale.com/types/appctype" -) - -func TestMakeConnectorsFromConfig(t *testing.T) { - tcs := []struct { - name string - input *appctype.AppConnectorConfig - want map[appctype.ConfigID]connector - }{ - { - "empty", - &appctype.AppConnectorConfig{}, - nil, - }, - { - "DNAT", - &appctype.AppConnectorConfig{ - DNAT: map[appctype.ConfigID]appctype.DNATConfig{ - "swiggity_swooty": { - Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}, - To: []string{"example.org"}, - IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}}, - }, - }, - }, - map[appctype.ConfigID]connector{ - "swiggity_swooty": { - Handlers: map[target]handler{ - { - Dest: netip.MustParsePrefix("100.64.0.1/32"), - Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, - }: &tcpRoundRobinHandler{To: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, - { - Dest: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), - Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, - }: &tcpRoundRobinHandler{To: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, - }, - }, - }, - }, - { - "SNIProxy", - &appctype.AppConnectorConfig{ - SNIProxy: map[appctype.ConfigID]appctype.SNIProxyConfig{ - "swiggity_swooty": { - Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}, - AllowedDomains: []string{"example.org"}, - IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}}, - }, - }, - }, - map[appctype.ConfigID]connector{ - "swiggity_swooty": { - Handlers: map[target]handler{ - { - Dest: netip.MustParsePrefix("100.64.0.1/32"), - Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, - }: &tcpSNIHandler{Allowlist: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, - { - Dest: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), - Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, - }: &tcpSNIHandler{Allowlist: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, - }, - }, - }, - }, - } - - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { - connectors := makeConnectorsFromConfig(tc.input) - - if diff := cmp.Diff(connectors, tc.want, - cmpopts.IgnoreFields(tcpRoundRobinHandler{}, "DialContext"), - cmpopts.IgnoreFields(tcpSNIHandler{}, "DialContext"), - cmp.Comparer(func(x, y netip.Addr) bool { - return x == y - })); diff != "" { - t.Fatalf("mismatch (-want +got):\n%s", diff) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "net/netip" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "tailscale.com/tailcfg" + "tailscale.com/types/appctype" +) + +func TestMakeConnectorsFromConfig(t *testing.T) { + tcs := []struct { + name string + input *appctype.AppConnectorConfig + want map[appctype.ConfigID]connector + }{ + { + "empty", + &appctype.AppConnectorConfig{}, + nil, + }, + { + "DNAT", + &appctype.AppConnectorConfig{ + DNAT: map[appctype.ConfigID]appctype.DNATConfig{ + "swiggity_swooty": { + Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}, + To: []string{"example.org"}, + IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}}, + }, + }, + }, + map[appctype.ConfigID]connector{ + "swiggity_swooty": { + Handlers: map[target]handler{ + { + Dest: netip.MustParsePrefix("100.64.0.1/32"), + Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, + }: &tcpRoundRobinHandler{To: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, + { + Dest: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), + Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, + }: &tcpRoundRobinHandler{To: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, + }, + }, + }, + }, + { + "SNIProxy", + &appctype.AppConnectorConfig{ + SNIProxy: map[appctype.ConfigID]appctype.SNIProxyConfig{ + "swiggity_swooty": { + Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}, + AllowedDomains: []string{"example.org"}, + IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}}, + }, + }, + }, + map[appctype.ConfigID]connector{ + "swiggity_swooty": { + Handlers: map[target]handler{ + { + Dest: netip.MustParsePrefix("100.64.0.1/32"), + Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, + }: &tcpSNIHandler{Allowlist: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, + { + Dest: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), + Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}, + }: &tcpSNIHandler{Allowlist: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}}, + }, + }, + }, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + connectors := makeConnectorsFromConfig(tc.input) + + if diff := cmp.Diff(connectors, tc.want, + cmpopts.IgnoreFields(tcpRoundRobinHandler{}, "DialContext"), + cmpopts.IgnoreFields(tcpSNIHandler{}, "DialContext"), + cmp.Comparer(func(x, y netip.Addr) bool { + return x == y + })); diff != "" { + t.Fatalf("mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/cmd/sniproxy/sniproxy.go b/cmd/sniproxy/sniproxy.go index c048c8e7e..fa83aaf4a 100644 --- a/cmd/sniproxy/sniproxy.go +++ b/cmd/sniproxy/sniproxy.go @@ -1,291 +1,291 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The sniproxy is an outbound SNI proxy. It receives TLS connections over -// Tailscale on one or more TCP ports and sends them out to the same SNI -// hostname & port on the internet. It can optionally forward one or more -// TCP ports to a specific destination. It only does TCP. -package main - -import ( - "context" - "errors" - "flag" - "fmt" - "log" - "net" - "net/http" - "net/netip" - "os" - "sort" - "strconv" - "strings" - - "github.com/peterbourgon/ff/v3" - "tailscale.com/client/tailscale" - "tailscale.com/hostinfo" - "tailscale.com/ipn" - "tailscale.com/tailcfg" - "tailscale.com/tsnet" - "tailscale.com/tsweb" - "tailscale.com/types/appctype" - "tailscale.com/types/ipproto" - "tailscale.com/types/nettype" - "tailscale.com/util/mak" -) - -const configCapKey = "tailscale.com/sniproxy" - -// portForward is the state for a single port forwarding entry, as passed to the --forward flag. -type portForward struct { - Port int - Proto string - Destination string -} - -// parseForward takes a proto/port/destination tuple as an input, as would be passed -// to the --forward command line flag, and returns a *portForward struct of those parameters. -func parseForward(value string) (*portForward, error) { - parts := strings.Split(value, "/") - if len(parts) != 3 { - return nil, errors.New("cannot parse: " + value) - } - - proto := parts[0] - if proto != "tcp" { - return nil, errors.New("unsupported forwarding protocol: " + proto) - } - port, err := strconv.ParseUint(parts[1], 10, 16) - if err != nil { - return nil, errors.New("bad forwarding port: " + parts[1]) - } - host := parts[2] - if host == "" { - return nil, errors.New("bad destination: " + value) - } - - return &portForward{Port: int(port), Proto: proto, Destination: host}, nil -} - -func main() { - // Parse flags - fs := flag.NewFlagSet("sniproxy", flag.ContinueOnError) - var ( - ports = fs.String("ports", "443", "comma-separated list of ports to proxy") - forwards = fs.String("forwards", "", "comma-separated list of ports to transparently forward, protocol/number/destination. For example, --forwards=tcp/22/github.com,tcp/5432/sql.example.com") - wgPort = fs.Int("wg-listen-port", 0, "UDP port to listen on for WireGuard and peer-to-peer traffic; 0 means automatically select") - promoteHTTPS = fs.Bool("promote-https", true, "promote HTTP to HTTPS") - debugPort = fs.Int("debug-port", 8893, "Listening port for debug/metrics endpoint") - hostname = fs.String("hostname", "", "Hostname to register the service under") - ) - err := ff.Parse(fs, os.Args[1:], ff.WithEnvVarPrefix("TS_APPC")) - if err != nil { - log.Fatal("ff.Parse") - } - - var ts tsnet.Server - defer ts.Close() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - run(ctx, &ts, *wgPort, *hostname, *promoteHTTPS, *debugPort, *ports, *forwards) -} - -// run actually runs the sniproxy. Its separate from main() to assist in testing. -func run(ctx context.Context, ts *tsnet.Server, wgPort int, hostname string, promoteHTTPS bool, debugPort int, ports, forwards string) { - // Wire up Tailscale node + app connector server - hostinfo.SetApp("sniproxy") - var s sniproxy - s.ts = ts - - s.ts.Port = uint16(wgPort) - s.ts.Hostname = hostname - - lc, err := s.ts.LocalClient() - if err != nil { - log.Fatalf("LocalClient() failed: %v", err) - } - s.lc = lc - s.ts.RegisterFallbackTCPHandler(s.srv.HandleTCPFlow) - - // Start special-purpose listeners: dns, http promotion, debug server - ln, err := s.ts.Listen("udp", ":53") - if err != nil { - log.Fatalf("failed listening on port 53: %v", err) - } - defer ln.Close() - go s.serveDNS(ln) - if promoteHTTPS { - ln, err := s.ts.Listen("tcp", ":80") - if err != nil { - log.Fatalf("failed listening on port 80: %v", err) - } - defer ln.Close() - log.Printf("Promoting HTTP to HTTPS ...") - go s.promoteHTTPS(ln) - } - if debugPort != 0 { - mux := http.NewServeMux() - tsweb.Debugger(mux) - dln, err := s.ts.Listen("tcp", fmt.Sprintf(":%d", debugPort)) - if err != nil { - log.Fatalf("failed listening on debug port: %v", err) - } - defer dln.Close() - go func() { - log.Fatalf("debug serve: %v", http.Serve(dln, mux)) - }() - } - - // Finally, start mainloop to configure app connector based on information - // in the netmap. - // We set the NotifyInitialNetMap flag so we will always get woken with the - // current netmap, before only being woken on changes. - bus, err := lc.WatchIPNBus(ctx, ipn.NotifyWatchEngineUpdates|ipn.NotifyInitialNetMap|ipn.NotifyNoPrivateKeys) - if err != nil { - log.Fatalf("watching IPN bus: %v", err) - } - defer bus.Close() - for { - msg, err := bus.Next() - if err != nil { - if errors.Is(err, context.Canceled) { - return - } - log.Fatalf("reading IPN bus: %v", err) - } - - // NetMap contains app-connector configuration - if nm := msg.NetMap; nm != nil && nm.SelfNode.Valid() { - sn := nm.SelfNode.AsStruct() - - var c appctype.AppConnectorConfig - nmConf, err := tailcfg.UnmarshalNodeCapJSON[appctype.AppConnectorConfig](sn.CapMap, configCapKey) - if err != nil { - log.Printf("failed to read app connector configuration from coordination server: %v", err) - } else if len(nmConf) > 0 { - c = nmConf[0] - } - - if c.AdvertiseRoutes { - if err := s.advertiseRoutesFromConfig(ctx, &c); err != nil { - log.Printf("failed to advertise routes: %v", err) - } - } - - // Backwards compatibility: combine any configuration from control with flags specified - // on the command line. This is intentionally done after we advertise any routes - // because its never correct to advertise the nodes native IP addresses. - s.mergeConfigFromFlags(&c, ports, forwards) - s.srv.Configure(&c) - } - } -} - -type sniproxy struct { - srv Server - ts *tsnet.Server - lc *tailscale.LocalClient -} - -func (s *sniproxy) advertiseRoutesFromConfig(ctx context.Context, c *appctype.AppConnectorConfig) error { - // Collect the set of addresses to advertise, using a map - // to avoid duplicate entries. - addrs := map[netip.Addr]struct{}{} - for _, c := range c.SNIProxy { - for _, ip := range c.Addrs { - addrs[ip] = struct{}{} - } - } - for _, c := range c.DNAT { - for _, ip := range c.Addrs { - addrs[ip] = struct{}{} - } - } - - var routes []netip.Prefix - for a := range addrs { - routes = append(routes, netip.PrefixFrom(a, a.BitLen())) - } - sort.SliceStable(routes, func(i, j int) bool { - return routes[i].Addr().Less(routes[j].Addr()) // determinism r us - }) - - _, err := s.lc.EditPrefs(ctx, &ipn.MaskedPrefs{ - Prefs: ipn.Prefs{ - AdvertiseRoutes: routes, - }, - AdvertiseRoutesSet: true, - }) - return err -} - -func (s *sniproxy) mergeConfigFromFlags(out *appctype.AppConnectorConfig, ports, forwards string) { - ip4, ip6 := s.ts.TailscaleIPs() - - sniConfigFromFlags := appctype.SNIProxyConfig{ - Addrs: []netip.Addr{ip4, ip6}, - } - if ports != "" { - for _, portStr := range strings.Split(ports, ",") { - port, err := strconv.ParseUint(portStr, 10, 16) - if err != nil { - log.Fatalf("invalid port: %s", portStr) - } - sniConfigFromFlags.IP = append(sniConfigFromFlags.IP, tailcfg.ProtoPortRange{ - Proto: int(ipproto.TCP), - Ports: tailcfg.PortRange{First: uint16(port), Last: uint16(port)}, - }) - } - } - - var forwardConfigFromFlags []appctype.DNATConfig - for _, forwStr := range strings.Split(forwards, ",") { - if forwStr == "" { - continue - } - forw, err := parseForward(forwStr) - if err != nil { - log.Printf("invalid forwarding spec: %v", err) - continue - } - - forwardConfigFromFlags = append(forwardConfigFromFlags, appctype.DNATConfig{ - Addrs: []netip.Addr{ip4, ip6}, - To: []string{forw.Destination}, - IP: []tailcfg.ProtoPortRange{ - { - Proto: int(ipproto.TCP), - Ports: tailcfg.PortRange{First: uint16(forw.Port), Last: uint16(forw.Port)}, - }, - }, - }) - } - - if len(forwardConfigFromFlags) == 0 && len(sniConfigFromFlags.IP) == 0 { - return // no config specified on the command line - } - - mak.Set(&out.SNIProxy, "flags", sniConfigFromFlags) - for i, forward := range forwardConfigFromFlags { - mak.Set(&out.DNAT, appctype.ConfigID(fmt.Sprintf("flags_%d", i)), forward) - } -} - -func (s *sniproxy) serveDNS(ln net.Listener) { - for { - c, err := ln.Accept() - if err != nil { - log.Printf("serveDNS accept: %v", err) - return - } - go s.srv.HandleDNS(c.(nettype.ConnPacketConn)) - } -} - -func (s *sniproxy) promoteHTTPS(ln net.Listener) { - err := http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.Redirect(w, r, "https://"+r.Host+r.RequestURI, http.StatusFound) - })) - log.Fatalf("promoteHTTPS http.Serve: %v", err) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The sniproxy is an outbound SNI proxy. It receives TLS connections over +// Tailscale on one or more TCP ports and sends them out to the same SNI +// hostname & port on the internet. It can optionally forward one or more +// TCP ports to a specific destination. It only does TCP. +package main + +import ( + "context" + "errors" + "flag" + "fmt" + "log" + "net" + "net/http" + "net/netip" + "os" + "sort" + "strconv" + "strings" + + "github.com/peterbourgon/ff/v3" + "tailscale.com/client/tailscale" + "tailscale.com/hostinfo" + "tailscale.com/ipn" + "tailscale.com/tailcfg" + "tailscale.com/tsnet" + "tailscale.com/tsweb" + "tailscale.com/types/appctype" + "tailscale.com/types/ipproto" + "tailscale.com/types/nettype" + "tailscale.com/util/mak" +) + +const configCapKey = "tailscale.com/sniproxy" + +// portForward is the state for a single port forwarding entry, as passed to the --forward flag. +type portForward struct { + Port int + Proto string + Destination string +} + +// parseForward takes a proto/port/destination tuple as an input, as would be passed +// to the --forward command line flag, and returns a *portForward struct of those parameters. +func parseForward(value string) (*portForward, error) { + parts := strings.Split(value, "/") + if len(parts) != 3 { + return nil, errors.New("cannot parse: " + value) + } + + proto := parts[0] + if proto != "tcp" { + return nil, errors.New("unsupported forwarding protocol: " + proto) + } + port, err := strconv.ParseUint(parts[1], 10, 16) + if err != nil { + return nil, errors.New("bad forwarding port: " + parts[1]) + } + host := parts[2] + if host == "" { + return nil, errors.New("bad destination: " + value) + } + + return &portForward{Port: int(port), Proto: proto, Destination: host}, nil +} + +func main() { + // Parse flags + fs := flag.NewFlagSet("sniproxy", flag.ContinueOnError) + var ( + ports = fs.String("ports", "443", "comma-separated list of ports to proxy") + forwards = fs.String("forwards", "", "comma-separated list of ports to transparently forward, protocol/number/destination. For example, --forwards=tcp/22/github.com,tcp/5432/sql.example.com") + wgPort = fs.Int("wg-listen-port", 0, "UDP port to listen on for WireGuard and peer-to-peer traffic; 0 means automatically select") + promoteHTTPS = fs.Bool("promote-https", true, "promote HTTP to HTTPS") + debugPort = fs.Int("debug-port", 8893, "Listening port for debug/metrics endpoint") + hostname = fs.String("hostname", "", "Hostname to register the service under") + ) + err := ff.Parse(fs, os.Args[1:], ff.WithEnvVarPrefix("TS_APPC")) + if err != nil { + log.Fatal("ff.Parse") + } + + var ts tsnet.Server + defer ts.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + run(ctx, &ts, *wgPort, *hostname, *promoteHTTPS, *debugPort, *ports, *forwards) +} + +// run actually runs the sniproxy. Its separate from main() to assist in testing. +func run(ctx context.Context, ts *tsnet.Server, wgPort int, hostname string, promoteHTTPS bool, debugPort int, ports, forwards string) { + // Wire up Tailscale node + app connector server + hostinfo.SetApp("sniproxy") + var s sniproxy + s.ts = ts + + s.ts.Port = uint16(wgPort) + s.ts.Hostname = hostname + + lc, err := s.ts.LocalClient() + if err != nil { + log.Fatalf("LocalClient() failed: %v", err) + } + s.lc = lc + s.ts.RegisterFallbackTCPHandler(s.srv.HandleTCPFlow) + + // Start special-purpose listeners: dns, http promotion, debug server + ln, err := s.ts.Listen("udp", ":53") + if err != nil { + log.Fatalf("failed listening on port 53: %v", err) + } + defer ln.Close() + go s.serveDNS(ln) + if promoteHTTPS { + ln, err := s.ts.Listen("tcp", ":80") + if err != nil { + log.Fatalf("failed listening on port 80: %v", err) + } + defer ln.Close() + log.Printf("Promoting HTTP to HTTPS ...") + go s.promoteHTTPS(ln) + } + if debugPort != 0 { + mux := http.NewServeMux() + tsweb.Debugger(mux) + dln, err := s.ts.Listen("tcp", fmt.Sprintf(":%d", debugPort)) + if err != nil { + log.Fatalf("failed listening on debug port: %v", err) + } + defer dln.Close() + go func() { + log.Fatalf("debug serve: %v", http.Serve(dln, mux)) + }() + } + + // Finally, start mainloop to configure app connector based on information + // in the netmap. + // We set the NotifyInitialNetMap flag so we will always get woken with the + // current netmap, before only being woken on changes. + bus, err := lc.WatchIPNBus(ctx, ipn.NotifyWatchEngineUpdates|ipn.NotifyInitialNetMap|ipn.NotifyNoPrivateKeys) + if err != nil { + log.Fatalf("watching IPN bus: %v", err) + } + defer bus.Close() + for { + msg, err := bus.Next() + if err != nil { + if errors.Is(err, context.Canceled) { + return + } + log.Fatalf("reading IPN bus: %v", err) + } + + // NetMap contains app-connector configuration + if nm := msg.NetMap; nm != nil && nm.SelfNode.Valid() { + sn := nm.SelfNode.AsStruct() + + var c appctype.AppConnectorConfig + nmConf, err := tailcfg.UnmarshalNodeCapJSON[appctype.AppConnectorConfig](sn.CapMap, configCapKey) + if err != nil { + log.Printf("failed to read app connector configuration from coordination server: %v", err) + } else if len(nmConf) > 0 { + c = nmConf[0] + } + + if c.AdvertiseRoutes { + if err := s.advertiseRoutesFromConfig(ctx, &c); err != nil { + log.Printf("failed to advertise routes: %v", err) + } + } + + // Backwards compatibility: combine any configuration from control with flags specified + // on the command line. This is intentionally done after we advertise any routes + // because its never correct to advertise the nodes native IP addresses. + s.mergeConfigFromFlags(&c, ports, forwards) + s.srv.Configure(&c) + } + } +} + +type sniproxy struct { + srv Server + ts *tsnet.Server + lc *tailscale.LocalClient +} + +func (s *sniproxy) advertiseRoutesFromConfig(ctx context.Context, c *appctype.AppConnectorConfig) error { + // Collect the set of addresses to advertise, using a map + // to avoid duplicate entries. + addrs := map[netip.Addr]struct{}{} + for _, c := range c.SNIProxy { + for _, ip := range c.Addrs { + addrs[ip] = struct{}{} + } + } + for _, c := range c.DNAT { + for _, ip := range c.Addrs { + addrs[ip] = struct{}{} + } + } + + var routes []netip.Prefix + for a := range addrs { + routes = append(routes, netip.PrefixFrom(a, a.BitLen())) + } + sort.SliceStable(routes, func(i, j int) bool { + return routes[i].Addr().Less(routes[j].Addr()) // determinism r us + }) + + _, err := s.lc.EditPrefs(ctx, &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{ + AdvertiseRoutes: routes, + }, + AdvertiseRoutesSet: true, + }) + return err +} + +func (s *sniproxy) mergeConfigFromFlags(out *appctype.AppConnectorConfig, ports, forwards string) { + ip4, ip6 := s.ts.TailscaleIPs() + + sniConfigFromFlags := appctype.SNIProxyConfig{ + Addrs: []netip.Addr{ip4, ip6}, + } + if ports != "" { + for _, portStr := range strings.Split(ports, ",") { + port, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + log.Fatalf("invalid port: %s", portStr) + } + sniConfigFromFlags.IP = append(sniConfigFromFlags.IP, tailcfg.ProtoPortRange{ + Proto: int(ipproto.TCP), + Ports: tailcfg.PortRange{First: uint16(port), Last: uint16(port)}, + }) + } + } + + var forwardConfigFromFlags []appctype.DNATConfig + for _, forwStr := range strings.Split(forwards, ",") { + if forwStr == "" { + continue + } + forw, err := parseForward(forwStr) + if err != nil { + log.Printf("invalid forwarding spec: %v", err) + continue + } + + forwardConfigFromFlags = append(forwardConfigFromFlags, appctype.DNATConfig{ + Addrs: []netip.Addr{ip4, ip6}, + To: []string{forw.Destination}, + IP: []tailcfg.ProtoPortRange{ + { + Proto: int(ipproto.TCP), + Ports: tailcfg.PortRange{First: uint16(forw.Port), Last: uint16(forw.Port)}, + }, + }, + }) + } + + if len(forwardConfigFromFlags) == 0 && len(sniConfigFromFlags.IP) == 0 { + return // no config specified on the command line + } + + mak.Set(&out.SNIProxy, "flags", sniConfigFromFlags) + for i, forward := range forwardConfigFromFlags { + mak.Set(&out.DNAT, appctype.ConfigID(fmt.Sprintf("flags_%d", i)), forward) + } +} + +func (s *sniproxy) serveDNS(ln net.Listener) { + for { + c, err := ln.Accept() + if err != nil { + log.Printf("serveDNS accept: %v", err) + return + } + go s.srv.HandleDNS(c.(nettype.ConnPacketConn)) + } +} + +func (s *sniproxy) promoteHTTPS(ln net.Listener) { + err := http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "https://"+r.Host+r.RequestURI, http.StatusFound) + })) + log.Fatalf("promoteHTTPS http.Serve: %v", err) +} diff --git a/cmd/speedtest/speedtest.go b/cmd/speedtest/speedtest.go index 1555c0dcc..9a457ed6c 100644 --- a/cmd/speedtest/speedtest.go +++ b/cmd/speedtest/speedtest.go @@ -1,121 +1,121 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Program speedtest provides the speedtest command. The reason to keep it separate from -// the normal tailscale cli is because it is not yet ready to go in the tailscale binary. -// It will be included in the tailscale cli after it has been added to tailscaled. - -// Example usage for client command: go run cmd/speedtest -host 127.0.0.1:20333 -t 5s -// This will connect to the server on 127.0.0.1:20333 and start a 5 second download speedtest. -// Example usage for server command: go run cmd/speedtest -s -host :20333 -// This will start a speedtest server on port 20333. -package main - -import ( - "context" - "errors" - "flag" - "fmt" - "net" - "os" - "strconv" - "text/tabwriter" - "time" - - "github.com/peterbourgon/ff/v3/ffcli" - "tailscale.com/net/speedtest" -) - -// Runs the speedtest command as a commandline program -func main() { - args := os.Args[1:] - if err := speedtestCmd.Parse(args); err != nil { - fmt.Fprintln(os.Stderr, err.Error()) - os.Exit(1) - } - - err := speedtestCmd.Run(context.Background()) - if errors.Is(err, flag.ErrHelp) { - fmt.Fprintln(os.Stderr, speedtestCmd.ShortUsage) - os.Exit(2) - } - if err != nil { - fmt.Fprintln(os.Stderr, err.Error()) - os.Exit(1) - } -} - -// speedtestCmd is the root command. It runs either the server or client depending on the -// flags passed to it. -var speedtestCmd = &ffcli.Command{ - Name: "speedtest", - ShortUsage: "speedtest [-host ] [-s] [-r] [-t ]", - ShortHelp: "Run a speed test", - FlagSet: (func() *flag.FlagSet { - fs := flag.NewFlagSet("speedtest", flag.ExitOnError) - fs.StringVar(&speedtestArgs.host, "host", ":20333", "host:port pair to connect to or listen on") - fs.DurationVar(&speedtestArgs.testDuration, "t", speedtest.DefaultDuration, "duration of the speed test") - fs.BoolVar(&speedtestArgs.runServer, "s", false, "run a speedtest server") - fs.BoolVar(&speedtestArgs.reverse, "r", false, "run in reverse mode (server sends, client receives)") - return fs - })(), - Exec: runSpeedtest, -} - -var speedtestArgs struct { - host string - testDuration time.Duration - runServer bool - reverse bool -} - -func runSpeedtest(ctx context.Context, args []string) error { - - if _, _, err := net.SplitHostPort(speedtestArgs.host); err != nil { - var addrErr *net.AddrError - if errors.As(err, &addrErr) && addrErr.Err == "missing port in address" { - // if no port is provided, append the default port - speedtestArgs.host = net.JoinHostPort(speedtestArgs.host, strconv.Itoa(speedtest.DefaultPort)) - } - } - - if speedtestArgs.runServer { - listener, err := net.Listen("tcp", speedtestArgs.host) - if err != nil { - return err - } - - fmt.Printf("listening on %v\n", listener.Addr()) - - return speedtest.Serve(listener) - } - - // Ensure the duration is within the allowed range - if speedtestArgs.testDuration < speedtest.MinDuration || speedtestArgs.testDuration > speedtest.MaxDuration { - return fmt.Errorf("test duration must be within %v and %v", speedtest.MinDuration, speedtest.MaxDuration) - } - - dir := speedtest.Download - if speedtestArgs.reverse { - dir = speedtest.Upload - } - - fmt.Printf("Starting a %s test with %s\n", dir, speedtestArgs.host) - results, err := speedtest.RunClient(dir, speedtestArgs.testDuration, speedtestArgs.host) - if err != nil { - return err - } - - w := tabwriter.NewWriter(os.Stdout, 12, 0, 0, ' ', tabwriter.TabIndent) - fmt.Println("Results:") - fmt.Fprintln(w, "Interval\t\tTransfer\t\tBandwidth\t\t") - startTime := results[0].IntervalStart - for _, r := range results { - if r.Total { - fmt.Fprintln(w, "-------------------------------------------------------------------------") - } - fmt.Fprintf(w, "%.2f-%.2f\tsec\t%.4f\tMBits\t%.4f\tMbits/sec\t\n", r.IntervalStart.Sub(startTime).Seconds(), r.IntervalEnd.Sub(startTime).Seconds(), r.MegaBits(), r.MBitsPerSecond()) - } - w.Flush() - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Program speedtest provides the speedtest command. The reason to keep it separate from +// the normal tailscale cli is because it is not yet ready to go in the tailscale binary. +// It will be included in the tailscale cli after it has been added to tailscaled. + +// Example usage for client command: go run cmd/speedtest -host 127.0.0.1:20333 -t 5s +// This will connect to the server on 127.0.0.1:20333 and start a 5 second download speedtest. +// Example usage for server command: go run cmd/speedtest -s -host :20333 +// This will start a speedtest server on port 20333. +package main + +import ( + "context" + "errors" + "flag" + "fmt" + "net" + "os" + "strconv" + "text/tabwriter" + "time" + + "github.com/peterbourgon/ff/v3/ffcli" + "tailscale.com/net/speedtest" +) + +// Runs the speedtest command as a commandline program +func main() { + args := os.Args[1:] + if err := speedtestCmd.Parse(args); err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + os.Exit(1) + } + + err := speedtestCmd.Run(context.Background()) + if errors.Is(err, flag.ErrHelp) { + fmt.Fprintln(os.Stderr, speedtestCmd.ShortUsage) + os.Exit(2) + } + if err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + os.Exit(1) + } +} + +// speedtestCmd is the root command. It runs either the server or client depending on the +// flags passed to it. +var speedtestCmd = &ffcli.Command{ + Name: "speedtest", + ShortUsage: "speedtest [-host ] [-s] [-r] [-t ]", + ShortHelp: "Run a speed test", + FlagSet: (func() *flag.FlagSet { + fs := flag.NewFlagSet("speedtest", flag.ExitOnError) + fs.StringVar(&speedtestArgs.host, "host", ":20333", "host:port pair to connect to or listen on") + fs.DurationVar(&speedtestArgs.testDuration, "t", speedtest.DefaultDuration, "duration of the speed test") + fs.BoolVar(&speedtestArgs.runServer, "s", false, "run a speedtest server") + fs.BoolVar(&speedtestArgs.reverse, "r", false, "run in reverse mode (server sends, client receives)") + return fs + })(), + Exec: runSpeedtest, +} + +var speedtestArgs struct { + host string + testDuration time.Duration + runServer bool + reverse bool +} + +func runSpeedtest(ctx context.Context, args []string) error { + + if _, _, err := net.SplitHostPort(speedtestArgs.host); err != nil { + var addrErr *net.AddrError + if errors.As(err, &addrErr) && addrErr.Err == "missing port in address" { + // if no port is provided, append the default port + speedtestArgs.host = net.JoinHostPort(speedtestArgs.host, strconv.Itoa(speedtest.DefaultPort)) + } + } + + if speedtestArgs.runServer { + listener, err := net.Listen("tcp", speedtestArgs.host) + if err != nil { + return err + } + + fmt.Printf("listening on %v\n", listener.Addr()) + + return speedtest.Serve(listener) + } + + // Ensure the duration is within the allowed range + if speedtestArgs.testDuration < speedtest.MinDuration || speedtestArgs.testDuration > speedtest.MaxDuration { + return fmt.Errorf("test duration must be within %v and %v", speedtest.MinDuration, speedtest.MaxDuration) + } + + dir := speedtest.Download + if speedtestArgs.reverse { + dir = speedtest.Upload + } + + fmt.Printf("Starting a %s test with %s\n", dir, speedtestArgs.host) + results, err := speedtest.RunClient(dir, speedtestArgs.testDuration, speedtestArgs.host) + if err != nil { + return err + } + + w := tabwriter.NewWriter(os.Stdout, 12, 0, 0, ' ', tabwriter.TabIndent) + fmt.Println("Results:") + fmt.Fprintln(w, "Interval\t\tTransfer\t\tBandwidth\t\t") + startTime := results[0].IntervalStart + for _, r := range results { + if r.Total { + fmt.Fprintln(w, "-------------------------------------------------------------------------") + } + fmt.Fprintf(w, "%.2f-%.2f\tsec\t%.4f\tMBits\t%.4f\tMbits/sec\t\n", r.IntervalStart.Sub(startTime).Seconds(), r.IntervalEnd.Sub(startTime).Seconds(), r.MegaBits(), r.MBitsPerSecond()) + } + w.Flush() + return nil +} diff --git a/cmd/ssh-auth-none-demo/ssh-auth-none-demo.go b/cmd/ssh-auth-none-demo/ssh-auth-none-demo.go index ade272c4b..ee929299a 100644 --- a/cmd/ssh-auth-none-demo/ssh-auth-none-demo.go +++ b/cmd/ssh-auth-none-demo/ssh-auth-none-demo.go @@ -1,187 +1,187 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// ssh-auth-none-demo is a demo SSH server that's meant to run on the -// public internet (at 188.166.70.128 port 2222) and -// highlight the unique parts of the Tailscale SSH server so SSH -// client authors can hit it easily and fix their SSH clients without -// needing to set up Tailscale and Tailscale SSH. -package main - -import ( - "crypto/ecdsa" - "crypto/ed25519" - "crypto/elliptic" - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "encoding/pem" - "flag" - "fmt" - "io" - "log" - "os" - "path/filepath" - "time" - - gossh "github.com/tailscale/golang-x-crypto/ssh" - "tailscale.com/tempfork/gliderlabs/ssh" -) - -// keyTypes are the SSH key types that we either try to read from the -// system's OpenSSH keys. -var keyTypes = []string{"rsa", "ecdsa", "ed25519"} - -var ( - addr = flag.String("addr", ":2222", "address to listen on") -) - -func main() { - flag.Parse() - - cacheDir, err := os.UserCacheDir() - if err != nil { - log.Fatal(err) - } - dir := filepath.Join(cacheDir, "ssh-auth-none-demo") - if err := os.MkdirAll(dir, 0700); err != nil { - log.Fatal(err) - } - - keys, err := getHostKeys(dir) - if err != nil { - log.Fatal(err) - } - if len(keys) == 0 { - log.Fatal("no host keys") - } - - srv := &ssh.Server{ - Addr: *addr, - Version: "Tailscale", - Handler: handleSessionPostSSHAuth, - ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig { - start := time.Now() - return &gossh.ServerConfig{ - NextAuthMethodCallback: func(conn gossh.ConnMetadata, prevErrors []error) []string { - return []string{"tailscale"} - }, - NoClientAuth: true, // required for the NoClientAuthCallback to run - NoClientAuthCallback: func(cm gossh.ConnMetadata) (*gossh.Permissions, error) { - cm.SendAuthBanner(fmt.Sprintf("# Banner: doing none auth at %v\r\n", time.Since(start))) - - totalBanners := 2 - if cm.User() == "banners" { - totalBanners = 5 - } - for banner := 2; banner <= totalBanners; banner++ { - time.Sleep(time.Second) - if banner == totalBanners { - cm.SendAuthBanner(fmt.Sprintf("# Banner%d: access granted at %v\r\n", banner, time.Since(start))) - } else { - cm.SendAuthBanner(fmt.Sprintf("# Banner%d at %v\r\n", banner, time.Since(start))) - } - } - return nil, nil - }, - BannerCallback: func(cm gossh.ConnMetadata) string { - log.Printf("Got connection from user %q, %q from %v", cm.User(), cm.ClientVersion(), cm.RemoteAddr()) - return fmt.Sprintf("# Banner for user %q, %q\n", cm.User(), cm.ClientVersion()) - }, - } - }, - } - - for _, signer := range keys { - srv.AddHostKey(signer) - } - - log.Printf("Running on %s ...", srv.Addr) - if err := srv.ListenAndServe(); err != nil { - log.Fatal(err) - } - log.Printf("done") -} - -func handleSessionPostSSHAuth(s ssh.Session) { - log.Printf("Started session from user %q", s.User()) - fmt.Fprintf(s, "Hello user %q, it worked.\n", s.User()) - - // Abort the session on Control-C or Control-D. - go func() { - buf := make([]byte, 1024) - for { - n, err := s.Read(buf) - for _, b := range buf[:n] { - if b <= 4 { // abort on Control-C (3) or Control-D (4) - io.WriteString(s, "bye\n") - s.Exit(1) - } - } - if err != nil { - return - } - } - }() - - for i := 10; i > 0; i-- { - fmt.Fprintf(s, "%v ...\n", i) - time.Sleep(time.Second) - } - s.Exit(0) -} - -func getHostKeys(dir string) (ret []ssh.Signer, err error) { - for _, typ := range keyTypes { - hostKey, err := hostKeyFileOrCreate(dir, typ) - if err != nil { - return nil, err - } - signer, err := gossh.ParsePrivateKey(hostKey) - if err != nil { - return nil, err - } - ret = append(ret, signer) - } - return ret, nil -} - -func hostKeyFileOrCreate(keyDir, typ string) ([]byte, error) { - path := filepath.Join(keyDir, "ssh_host_"+typ+"_key") - v, err := os.ReadFile(path) - if err == nil { - return v, nil - } - if !os.IsNotExist(err) { - return nil, err - } - var priv any - switch typ { - default: - return nil, fmt.Errorf("unsupported key type %q", typ) - case "ed25519": - _, priv, err = ed25519.GenerateKey(rand.Reader) - case "ecdsa": - // curve is arbitrary. We pick whatever will at - // least pacify clients as the actual encryption - // doesn't matter: it's all over WireGuard anyway. - curve := elliptic.P256() - priv, err = ecdsa.GenerateKey(curve, rand.Reader) - case "rsa": - // keySize is arbitrary. We pick whatever will at - // least pacify clients as the actual encryption - // doesn't matter: it's all over WireGuard anyway. - const keySize = 2048 - priv, err = rsa.GenerateKey(rand.Reader, keySize) - } - if err != nil { - return nil, err - } - mk, err := x509.MarshalPKCS8PrivateKey(priv) - if err != nil { - return nil, err - } - pemGen := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: mk}) - err = os.WriteFile(path, pemGen, 0700) - return pemGen, err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// ssh-auth-none-demo is a demo SSH server that's meant to run on the +// public internet (at 188.166.70.128 port 2222) and +// highlight the unique parts of the Tailscale SSH server so SSH +// client authors can hit it easily and fix their SSH clients without +// needing to set up Tailscale and Tailscale SSH. +package main + +import ( + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "flag" + "fmt" + "io" + "log" + "os" + "path/filepath" + "time" + + gossh "github.com/tailscale/golang-x-crypto/ssh" + "tailscale.com/tempfork/gliderlabs/ssh" +) + +// keyTypes are the SSH key types that we either try to read from the +// system's OpenSSH keys. +var keyTypes = []string{"rsa", "ecdsa", "ed25519"} + +var ( + addr = flag.String("addr", ":2222", "address to listen on") +) + +func main() { + flag.Parse() + + cacheDir, err := os.UserCacheDir() + if err != nil { + log.Fatal(err) + } + dir := filepath.Join(cacheDir, "ssh-auth-none-demo") + if err := os.MkdirAll(dir, 0700); err != nil { + log.Fatal(err) + } + + keys, err := getHostKeys(dir) + if err != nil { + log.Fatal(err) + } + if len(keys) == 0 { + log.Fatal("no host keys") + } + + srv := &ssh.Server{ + Addr: *addr, + Version: "Tailscale", + Handler: handleSessionPostSSHAuth, + ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig { + start := time.Now() + return &gossh.ServerConfig{ + NextAuthMethodCallback: func(conn gossh.ConnMetadata, prevErrors []error) []string { + return []string{"tailscale"} + }, + NoClientAuth: true, // required for the NoClientAuthCallback to run + NoClientAuthCallback: func(cm gossh.ConnMetadata) (*gossh.Permissions, error) { + cm.SendAuthBanner(fmt.Sprintf("# Banner: doing none auth at %v\r\n", time.Since(start))) + + totalBanners := 2 + if cm.User() == "banners" { + totalBanners = 5 + } + for banner := 2; banner <= totalBanners; banner++ { + time.Sleep(time.Second) + if banner == totalBanners { + cm.SendAuthBanner(fmt.Sprintf("# Banner%d: access granted at %v\r\n", banner, time.Since(start))) + } else { + cm.SendAuthBanner(fmt.Sprintf("# Banner%d at %v\r\n", banner, time.Since(start))) + } + } + return nil, nil + }, + BannerCallback: func(cm gossh.ConnMetadata) string { + log.Printf("Got connection from user %q, %q from %v", cm.User(), cm.ClientVersion(), cm.RemoteAddr()) + return fmt.Sprintf("# Banner for user %q, %q\n", cm.User(), cm.ClientVersion()) + }, + } + }, + } + + for _, signer := range keys { + srv.AddHostKey(signer) + } + + log.Printf("Running on %s ...", srv.Addr) + if err := srv.ListenAndServe(); err != nil { + log.Fatal(err) + } + log.Printf("done") +} + +func handleSessionPostSSHAuth(s ssh.Session) { + log.Printf("Started session from user %q", s.User()) + fmt.Fprintf(s, "Hello user %q, it worked.\n", s.User()) + + // Abort the session on Control-C or Control-D. + go func() { + buf := make([]byte, 1024) + for { + n, err := s.Read(buf) + for _, b := range buf[:n] { + if b <= 4 { // abort on Control-C (3) or Control-D (4) + io.WriteString(s, "bye\n") + s.Exit(1) + } + } + if err != nil { + return + } + } + }() + + for i := 10; i > 0; i-- { + fmt.Fprintf(s, "%v ...\n", i) + time.Sleep(time.Second) + } + s.Exit(0) +} + +func getHostKeys(dir string) (ret []ssh.Signer, err error) { + for _, typ := range keyTypes { + hostKey, err := hostKeyFileOrCreate(dir, typ) + if err != nil { + return nil, err + } + signer, err := gossh.ParsePrivateKey(hostKey) + if err != nil { + return nil, err + } + ret = append(ret, signer) + } + return ret, nil +} + +func hostKeyFileOrCreate(keyDir, typ string) ([]byte, error) { + path := filepath.Join(keyDir, "ssh_host_"+typ+"_key") + v, err := os.ReadFile(path) + if err == nil { + return v, nil + } + if !os.IsNotExist(err) { + return nil, err + } + var priv any + switch typ { + default: + return nil, fmt.Errorf("unsupported key type %q", typ) + case "ed25519": + _, priv, err = ed25519.GenerateKey(rand.Reader) + case "ecdsa": + // curve is arbitrary. We pick whatever will at + // least pacify clients as the actual encryption + // doesn't matter: it's all over WireGuard anyway. + curve := elliptic.P256() + priv, err = ecdsa.GenerateKey(curve, rand.Reader) + case "rsa": + // keySize is arbitrary. We pick whatever will at + // least pacify clients as the actual encryption + // doesn't matter: it's all over WireGuard anyway. + const keySize = 2048 + priv, err = rsa.GenerateKey(rand.Reader, keySize) + } + if err != nil { + return nil, err + } + mk, err := x509.MarshalPKCS8PrivateKey(priv) + if err != nil { + return nil, err + } + pemGen := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: mk}) + err = os.WriteFile(path, pemGen, 0700) + return pemGen, err +} diff --git a/cmd/sync-containers/main.go b/cmd/sync-containers/main.go index 68308cfeb..6317b4943 100644 --- a/cmd/sync-containers/main.go +++ b/cmd/sync-containers/main.go @@ -1,214 +1,214 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !plan9 - -// The sync-containers command synchronizes container image tags from one -// registry to another. -// -// It is intended as a workaround for ghcr.io's lack of good push credentials: -// you can either authorize "classic" Personal Access Tokens in your org (which -// are a common vector of very bad compromise), or you can get a short-lived -// credential in a Github action. -// -// Since we publish to both Docker Hub and ghcr.io, we use this program in a -// Github action to effectively rsync from docker hub into ghcr.io, so that we -// can continue to forbid dangerous Personal Access Tokens in the tailscale org. -package main - -import ( - "context" - "flag" - "fmt" - "log" - "sort" - "strings" - - "github.com/google/go-containerregistry/pkg/authn" - "github.com/google/go-containerregistry/pkg/authn/github" - "github.com/google/go-containerregistry/pkg/name" - v1 "github.com/google/go-containerregistry/pkg/v1" - "github.com/google/go-containerregistry/pkg/v1/remote" - "github.com/google/go-containerregistry/pkg/v1/types" -) - -var ( - src = flag.String("src", "", "Source image") - dst = flag.String("dst", "", "Destination image") - max = flag.Int("max", 0, "Maximum number of tags to sync (0 for all tags)") - dryRun = flag.Bool("dry-run", true, "Don't actually sync anything") -) - -func main() { - flag.Parse() - - if *src == "" { - log.Fatalf("--src is required") - } - if *dst == "" { - log.Fatalf("--dst is required") - } - - keychain := authn.NewMultiKeychain(authn.DefaultKeychain, github.Keychain) - opts := []remote.Option{ - remote.WithAuthFromKeychain(keychain), - remote.WithContext(context.Background()), - } - - stags, err := listTags(*src, opts...) - if err != nil { - log.Fatalf("listing source tags: %v", err) - } - dtags, err := listTags(*dst, opts...) - if err != nil { - log.Fatalf("listing destination tags: %v", err) - } - - add, remove := diffTags(stags, dtags) - if l := len(add); l > 0 { - log.Printf("%d tags to push: %s", len(add), strings.Join(add, ", ")) - if *max > 0 && l > *max { - log.Printf("Limiting sync to %d tags", *max) - add = add[:*max] - } - } - for _, tag := range add { - if !*dryRun { - log.Printf("Syncing tag %q", tag) - if err := copyTag(*src, *dst, tag, opts...); err != nil { - log.Printf("Syncing tag %q: progress error: %v", tag, err) - } - } else { - log.Printf("Dry run: would sync tag %q", tag) - } - } - - if len(remove) > 0 { - log.Printf("%d tags to remove: %s\n", len(remove), strings.Join(remove, ", ")) - log.Printf("Not removing any tags for safety.\n") - } - - var wellKnown = [...]string{"latest", "stable"} - for _, tag := range wellKnown { - if needsUpdate(*src, *dst, tag) { - if err := copyTag(*src, *dst, tag, opts...); err != nil { - log.Printf("Updating tag %q: progress error: %v", tag, err) - } - } - } -} - -func copyTag(srcStr, dstStr, tag string, opts ...remote.Option) error { - src, err := name.ParseReference(fmt.Sprintf("%s:%s", srcStr, tag)) - if err != nil { - return err - } - dst, err := name.ParseReference(fmt.Sprintf("%s:%s", dstStr, tag)) - if err != nil { - return err - } - - desc, err := remote.Get(src) - if err != nil { - return err - } - - ch := make(chan v1.Update, 10) - opts = append(opts, remote.WithProgress(ch)) - progressDone := make(chan struct{}) - - go func() { - defer close(progressDone) - for p := range ch { - fmt.Printf("Syncing tag %q: %d%% (%d/%d)\n", tag, int(float64(p.Complete)/float64(p.Total)*100), p.Complete, p.Total) - if p.Error != nil { - fmt.Printf("error: %v\n", p.Error) - } - } - }() - - switch desc.MediaType { - case types.OCIManifestSchema1, types.DockerManifestSchema2: - img, err := desc.Image() - if err != nil { - return err - } - if err := remote.Write(dst, img, opts...); err != nil { - return err - } - case types.OCIImageIndex, types.DockerManifestList: - idx, err := desc.ImageIndex() - if err != nil { - return err - } - if err := remote.WriteIndex(dst, idx, opts...); err != nil { - return err - } - } - - <-progressDone - return nil -} - -func listTags(repoStr string, opts ...remote.Option) ([]string, error) { - repo, err := name.NewRepository(repoStr) - if err != nil { - return nil, err - } - - tags, err := remote.List(repo, opts...) - if err != nil { - return nil, err - } - - sort.Strings(tags) - return tags, nil -} - -func diffTags(src, dst []string) (add, remove []string) { - srcd := make(map[string]bool) - for _, tag := range src { - srcd[tag] = true - } - dstd := make(map[string]bool) - for _, tag := range dst { - dstd[tag] = true - } - - for _, tag := range src { - if !dstd[tag] { - add = append(add, tag) - } - } - for _, tag := range dst { - if !srcd[tag] { - remove = append(remove, tag) - } - } - sort.Strings(add) - sort.Strings(remove) - return add, remove -} - -func needsUpdate(srcStr, dstStr, tag string) bool { - src, err := name.ParseReference(fmt.Sprintf("%s:%s", srcStr, tag)) - if err != nil { - return false - } - dst, err := name.ParseReference(fmt.Sprintf("%s:%s", dstStr, tag)) - if err != nil { - return false - } - - srcDesc, err := remote.Get(src) - if err != nil { - return false - } - - dstDesc, err := remote.Get(dst) - if err != nil { - return true - } - - return srcDesc.Digest != dstDesc.Digest -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +// The sync-containers command synchronizes container image tags from one +// registry to another. +// +// It is intended as a workaround for ghcr.io's lack of good push credentials: +// you can either authorize "classic" Personal Access Tokens in your org (which +// are a common vector of very bad compromise), or you can get a short-lived +// credential in a Github action. +// +// Since we publish to both Docker Hub and ghcr.io, we use this program in a +// Github action to effectively rsync from docker hub into ghcr.io, so that we +// can continue to forbid dangerous Personal Access Tokens in the tailscale org. +package main + +import ( + "context" + "flag" + "fmt" + "log" + "sort" + "strings" + + "github.com/google/go-containerregistry/pkg/authn" + "github.com/google/go-containerregistry/pkg/authn/github" + "github.com/google/go-containerregistry/pkg/name" + v1 "github.com/google/go-containerregistry/pkg/v1" + "github.com/google/go-containerregistry/pkg/v1/remote" + "github.com/google/go-containerregistry/pkg/v1/types" +) + +var ( + src = flag.String("src", "", "Source image") + dst = flag.String("dst", "", "Destination image") + max = flag.Int("max", 0, "Maximum number of tags to sync (0 for all tags)") + dryRun = flag.Bool("dry-run", true, "Don't actually sync anything") +) + +func main() { + flag.Parse() + + if *src == "" { + log.Fatalf("--src is required") + } + if *dst == "" { + log.Fatalf("--dst is required") + } + + keychain := authn.NewMultiKeychain(authn.DefaultKeychain, github.Keychain) + opts := []remote.Option{ + remote.WithAuthFromKeychain(keychain), + remote.WithContext(context.Background()), + } + + stags, err := listTags(*src, opts...) + if err != nil { + log.Fatalf("listing source tags: %v", err) + } + dtags, err := listTags(*dst, opts...) + if err != nil { + log.Fatalf("listing destination tags: %v", err) + } + + add, remove := diffTags(stags, dtags) + if l := len(add); l > 0 { + log.Printf("%d tags to push: %s", len(add), strings.Join(add, ", ")) + if *max > 0 && l > *max { + log.Printf("Limiting sync to %d tags", *max) + add = add[:*max] + } + } + for _, tag := range add { + if !*dryRun { + log.Printf("Syncing tag %q", tag) + if err := copyTag(*src, *dst, tag, opts...); err != nil { + log.Printf("Syncing tag %q: progress error: %v", tag, err) + } + } else { + log.Printf("Dry run: would sync tag %q", tag) + } + } + + if len(remove) > 0 { + log.Printf("%d tags to remove: %s\n", len(remove), strings.Join(remove, ", ")) + log.Printf("Not removing any tags for safety.\n") + } + + var wellKnown = [...]string{"latest", "stable"} + for _, tag := range wellKnown { + if needsUpdate(*src, *dst, tag) { + if err := copyTag(*src, *dst, tag, opts...); err != nil { + log.Printf("Updating tag %q: progress error: %v", tag, err) + } + } + } +} + +func copyTag(srcStr, dstStr, tag string, opts ...remote.Option) error { + src, err := name.ParseReference(fmt.Sprintf("%s:%s", srcStr, tag)) + if err != nil { + return err + } + dst, err := name.ParseReference(fmt.Sprintf("%s:%s", dstStr, tag)) + if err != nil { + return err + } + + desc, err := remote.Get(src) + if err != nil { + return err + } + + ch := make(chan v1.Update, 10) + opts = append(opts, remote.WithProgress(ch)) + progressDone := make(chan struct{}) + + go func() { + defer close(progressDone) + for p := range ch { + fmt.Printf("Syncing tag %q: %d%% (%d/%d)\n", tag, int(float64(p.Complete)/float64(p.Total)*100), p.Complete, p.Total) + if p.Error != nil { + fmt.Printf("error: %v\n", p.Error) + } + } + }() + + switch desc.MediaType { + case types.OCIManifestSchema1, types.DockerManifestSchema2: + img, err := desc.Image() + if err != nil { + return err + } + if err := remote.Write(dst, img, opts...); err != nil { + return err + } + case types.OCIImageIndex, types.DockerManifestList: + idx, err := desc.ImageIndex() + if err != nil { + return err + } + if err := remote.WriteIndex(dst, idx, opts...); err != nil { + return err + } + } + + <-progressDone + return nil +} + +func listTags(repoStr string, opts ...remote.Option) ([]string, error) { + repo, err := name.NewRepository(repoStr) + if err != nil { + return nil, err + } + + tags, err := remote.List(repo, opts...) + if err != nil { + return nil, err + } + + sort.Strings(tags) + return tags, nil +} + +func diffTags(src, dst []string) (add, remove []string) { + srcd := make(map[string]bool) + for _, tag := range src { + srcd[tag] = true + } + dstd := make(map[string]bool) + for _, tag := range dst { + dstd[tag] = true + } + + for _, tag := range src { + if !dstd[tag] { + add = append(add, tag) + } + } + for _, tag := range dst { + if !srcd[tag] { + remove = append(remove, tag) + } + } + sort.Strings(add) + sort.Strings(remove) + return add, remove +} + +func needsUpdate(srcStr, dstStr, tag string) bool { + src, err := name.ParseReference(fmt.Sprintf("%s:%s", srcStr, tag)) + if err != nil { + return false + } + dst, err := name.ParseReference(fmt.Sprintf("%s:%s", dstStr, tag)) + if err != nil { + return false + } + + srcDesc, err := remote.Get(src) + if err != nil { + return false + } + + dstDesc, err := remote.Get(dst) + if err != nil { + return true + } + + return srcDesc.Digest != dstDesc.Digest +} diff --git a/cmd/tailscale/cli/diag.go b/cmd/tailscale/cli/diag.go index a1616f851..ebf26985f 100644 --- a/cmd/tailscale/cli/diag.go +++ b/cmd/tailscale/cli/diag.go @@ -1,74 +1,74 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux || windows || darwin - -package cli - -import ( - "fmt" - "os/exec" - "path/filepath" - "runtime" - "strings" - - ps "github.com/mitchellh/go-ps" - "tailscale.com/version/distro" -) - -// fixTailscaledConnectError is called when the local tailscaled has -// been determined unreachable due to the provided origErr value. It -// returns either the same error or a better one to help the user -// understand why tailscaled isn't running for their platform. -func fixTailscaledConnectError(origErr error) error { - procs, err := ps.Processes() - if err != nil { - return fmt.Errorf("failed to connect to local Tailscaled process and failed to enumerate processes while looking for it") - } - var foundProc ps.Process - for _, proc := range procs { - base := filepath.Base(proc.Executable()) - if base == "tailscaled" { - foundProc = proc - break - } - if runtime.GOOS == "darwin" && base == "IPNExtension" { - foundProc = proc - break - } - if runtime.GOOS == "windows" && strings.EqualFold(base, "tailscaled.exe") { - foundProc = proc - break - } - } - if foundProc == nil { - switch runtime.GOOS { - case "windows": - return fmt.Errorf("failed to connect to local tailscaled process; is the Tailscale service running?") - case "darwin": - return fmt.Errorf("failed to connect to local Tailscale service; is Tailscale running?") - case "linux": - var hint string - if isSystemdSystem() { - hint = " (sudo systemctl start tailscaled ?)" - } - return fmt.Errorf("failed to connect to local tailscaled; it doesn't appear to be running%s", hint) - } - return fmt.Errorf("failed to connect to local tailscaled process; it doesn't appear to be running") - } - return fmt.Errorf("failed to connect to local tailscaled (which appears to be running as %v, pid %v). Got error: %w", foundProc.Executable(), foundProc.Pid(), origErr) -} - -// isSystemdSystem reports whether the current machine uses systemd -// and in particular whether the systemctl command is available. -func isSystemdSystem() bool { - if runtime.GOOS != "linux" { - return false - } - switch distro.Get() { - case distro.QNAP, distro.Gokrazy, distro.Synology, distro.Unraid: - return false - } - _, err := exec.LookPath("systemctl") - return err == nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux || windows || darwin + +package cli + +import ( + "fmt" + "os/exec" + "path/filepath" + "runtime" + "strings" + + ps "github.com/mitchellh/go-ps" + "tailscale.com/version/distro" +) + +// fixTailscaledConnectError is called when the local tailscaled has +// been determined unreachable due to the provided origErr value. It +// returns either the same error or a better one to help the user +// understand why tailscaled isn't running for their platform. +func fixTailscaledConnectError(origErr error) error { + procs, err := ps.Processes() + if err != nil { + return fmt.Errorf("failed to connect to local Tailscaled process and failed to enumerate processes while looking for it") + } + var foundProc ps.Process + for _, proc := range procs { + base := filepath.Base(proc.Executable()) + if base == "tailscaled" { + foundProc = proc + break + } + if runtime.GOOS == "darwin" && base == "IPNExtension" { + foundProc = proc + break + } + if runtime.GOOS == "windows" && strings.EqualFold(base, "tailscaled.exe") { + foundProc = proc + break + } + } + if foundProc == nil { + switch runtime.GOOS { + case "windows": + return fmt.Errorf("failed to connect to local tailscaled process; is the Tailscale service running?") + case "darwin": + return fmt.Errorf("failed to connect to local Tailscale service; is Tailscale running?") + case "linux": + var hint string + if isSystemdSystem() { + hint = " (sudo systemctl start tailscaled ?)" + } + return fmt.Errorf("failed to connect to local tailscaled; it doesn't appear to be running%s", hint) + } + return fmt.Errorf("failed to connect to local tailscaled process; it doesn't appear to be running") + } + return fmt.Errorf("failed to connect to local tailscaled (which appears to be running as %v, pid %v). Got error: %w", foundProc.Executable(), foundProc.Pid(), origErr) +} + +// isSystemdSystem reports whether the current machine uses systemd +// and in particular whether the systemctl command is available. +func isSystemdSystem() bool { + if runtime.GOOS != "linux" { + return false + } + switch distro.Get() { + case distro.QNAP, distro.Gokrazy, distro.Synology, distro.Unraid: + return false + } + _, err := exec.LookPath("systemctl") + return err == nil +} diff --git a/cmd/tailscale/cli/diag_other.go b/cmd/tailscale/cli/diag_other.go index 82058ef7a..ece10cc79 100644 --- a/cmd/tailscale/cli/diag_other.go +++ b/cmd/tailscale/cli/diag_other.go @@ -1,15 +1,15 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux && !windows && !darwin - -package cli - -import "fmt" - -// The github.com/mitchellh/go-ps package doesn't work on all platforms, -// so just don't diagnose connect failures. - -func fixTailscaledConnectError(origErr error) error { - return fmt.Errorf("failed to connect to local tailscaled process (is it running?); got: %w", origErr) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux && !windows && !darwin + +package cli + +import "fmt" + +// The github.com/mitchellh/go-ps package doesn't work on all platforms, +// so just don't diagnose connect failures. + +func fixTailscaledConnectError(origErr error) error { + return fmt.Errorf("failed to connect to local tailscaled process (is it running?); got: %w", origErr) +} diff --git a/cmd/tailscale/cli/set_test.go b/cmd/tailscale/cli/set_test.go index 06ef8503f..15305c3ce 100644 --- a/cmd/tailscale/cli/set_test.go +++ b/cmd/tailscale/cli/set_test.go @@ -1,131 +1,131 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package cli - -import ( - "net/netip" - "reflect" - "testing" - - "tailscale.com/ipn" - "tailscale.com/net/tsaddr" - "tailscale.com/types/ptr" -) - -func TestCalcAdvertiseRoutesForSet(t *testing.T) { - pfx := netip.MustParsePrefix - tests := []struct { - name string - setExit *bool - setRoutes *string - was []netip.Prefix - want []netip.Prefix - }{ - { - name: "empty", - }, - { - name: "advertise-exit", - setExit: ptr.To(true), - want: tsaddr.ExitRoutes(), - }, - { - name: "advertise-exit/already-routes", - was: []netip.Prefix{pfx("34.0.0.0/16")}, - setExit: ptr.To(true), - want: []netip.Prefix{pfx("34.0.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, - }, - { - name: "advertise-exit/already-exit", - was: tsaddr.ExitRoutes(), - setExit: ptr.To(true), - want: tsaddr.ExitRoutes(), - }, - { - name: "stop-advertise-exit", - was: tsaddr.ExitRoutes(), - setExit: ptr.To(false), - want: nil, - }, - { - name: "stop-advertise-exit/with-routes", - was: []netip.Prefix{pfx("34.0.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, - setExit: ptr.To(false), - want: []netip.Prefix{pfx("34.0.0.0/16")}, - }, - { - name: "advertise-routes", - setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), - want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16")}, - }, - { - name: "advertise-routes/already-exit", - was: tsaddr.ExitRoutes(), - setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), - want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, - }, - { - name: "advertise-routes/already-diff-routes", - was: []netip.Prefix{pfx("34.0.0.0/16")}, - setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), - want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16")}, - }, - { - name: "stop-advertise-routes", - was: []netip.Prefix{pfx("34.0.0.0/16")}, - setRoutes: ptr.To(""), - want: nil, - }, - { - name: "stop-advertise-routes/already-exit", - was: []netip.Prefix{pfx("34.0.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, - setRoutes: ptr.To(""), - want: tsaddr.ExitRoutes(), - }, - { - name: "advertise-routes-and-exit", - setExit: ptr.To(true), - setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), - want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, - }, - { - name: "advertise-routes-and-exit/already-exit", - was: tsaddr.ExitRoutes(), - setExit: ptr.To(true), - setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), - want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, - }, - { - name: "advertise-routes-and-exit/already-routes", - was: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16")}, - setExit: ptr.To(true), - setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), - want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - curPrefs := &ipn.Prefs{ - AdvertiseRoutes: tc.was, - } - sa := setArgsT{} - if tc.setExit != nil { - sa.advertiseDefaultRoute = *tc.setExit - } - if tc.setRoutes != nil { - sa.advertiseRoutes = *tc.setRoutes - } - got, err := calcAdvertiseRoutesForSet(tc.setExit != nil, tc.setRoutes != nil, curPrefs, sa) - if err != nil { - t.Fatal(err) - } - tsaddr.SortPrefixes(got) - tsaddr.SortPrefixes(tc.want) - if !reflect.DeepEqual(got, tc.want) { - t.Errorf("got %v, want %v", got, tc.want) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package cli + +import ( + "net/netip" + "reflect" + "testing" + + "tailscale.com/ipn" + "tailscale.com/net/tsaddr" + "tailscale.com/types/ptr" +) + +func TestCalcAdvertiseRoutesForSet(t *testing.T) { + pfx := netip.MustParsePrefix + tests := []struct { + name string + setExit *bool + setRoutes *string + was []netip.Prefix + want []netip.Prefix + }{ + { + name: "empty", + }, + { + name: "advertise-exit", + setExit: ptr.To(true), + want: tsaddr.ExitRoutes(), + }, + { + name: "advertise-exit/already-routes", + was: []netip.Prefix{pfx("34.0.0.0/16")}, + setExit: ptr.To(true), + want: []netip.Prefix{pfx("34.0.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, + }, + { + name: "advertise-exit/already-exit", + was: tsaddr.ExitRoutes(), + setExit: ptr.To(true), + want: tsaddr.ExitRoutes(), + }, + { + name: "stop-advertise-exit", + was: tsaddr.ExitRoutes(), + setExit: ptr.To(false), + want: nil, + }, + { + name: "stop-advertise-exit/with-routes", + was: []netip.Prefix{pfx("34.0.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, + setExit: ptr.To(false), + want: []netip.Prefix{pfx("34.0.0.0/16")}, + }, + { + name: "advertise-routes", + setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), + want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16")}, + }, + { + name: "advertise-routes/already-exit", + was: tsaddr.ExitRoutes(), + setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), + want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, + }, + { + name: "advertise-routes/already-diff-routes", + was: []netip.Prefix{pfx("34.0.0.0/16")}, + setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), + want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16")}, + }, + { + name: "stop-advertise-routes", + was: []netip.Prefix{pfx("34.0.0.0/16")}, + setRoutes: ptr.To(""), + want: nil, + }, + { + name: "stop-advertise-routes/already-exit", + was: []netip.Prefix{pfx("34.0.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, + setRoutes: ptr.To(""), + want: tsaddr.ExitRoutes(), + }, + { + name: "advertise-routes-and-exit", + setExit: ptr.To(true), + setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), + want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, + }, + { + name: "advertise-routes-and-exit/already-exit", + was: tsaddr.ExitRoutes(), + setExit: ptr.To(true), + setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), + want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, + }, + { + name: "advertise-routes-and-exit/already-routes", + was: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16")}, + setExit: ptr.To(true), + setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), + want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + curPrefs := &ipn.Prefs{ + AdvertiseRoutes: tc.was, + } + sa := setArgsT{} + if tc.setExit != nil { + sa.advertiseDefaultRoute = *tc.setExit + } + if tc.setRoutes != nil { + sa.advertiseRoutes = *tc.setRoutes + } + got, err := calcAdvertiseRoutesForSet(tc.setExit != nil, tc.setRoutes != nil, curPrefs, sa) + if err != nil { + t.Fatal(err) + } + tsaddr.SortPrefixes(got) + tsaddr.SortPrefixes(tc.want) + if !reflect.DeepEqual(got, tc.want) { + t.Errorf("got %v, want %v", got, tc.want) + } + }) + } +} diff --git a/cmd/tailscale/cli/ssh_exec.go b/cmd/tailscale/cli/ssh_exec.go index 7f7d2a4d5..10e52903d 100644 --- a/cmd/tailscale/cli/ssh_exec.go +++ b/cmd/tailscale/cli/ssh_exec.go @@ -1,24 +1,24 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !js && !windows - -package cli - -import ( - "errors" - "os" - "os/exec" - "syscall" -) - -func findSSH() (string, error) { - return exec.LookPath("ssh") -} - -func execSSH(ssh string, argv []string) error { - if err := syscall.Exec(ssh, argv, os.Environ()); err != nil { - return err - } - return errors.New("unreachable") -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !js && !windows + +package cli + +import ( + "errors" + "os" + "os/exec" + "syscall" +) + +func findSSH() (string, error) { + return exec.LookPath("ssh") +} + +func execSSH(ssh string, argv []string) error { + if err := syscall.Exec(ssh, argv, os.Environ()); err != nil { + return err + } + return errors.New("unreachable") +} diff --git a/cmd/tailscale/cli/ssh_exec_js.go b/cmd/tailscale/cli/ssh_exec_js.go index aa0c09e89..40effc7ca 100644 --- a/cmd/tailscale/cli/ssh_exec_js.go +++ b/cmd/tailscale/cli/ssh_exec_js.go @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package cli - -import ( - "errors" -) - -func findSSH() (string, error) { - return "", errors.New("Not implemented") -} - -func execSSH(ssh string, argv []string) error { - return errors.New("Not implemented") -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package cli + +import ( + "errors" +) + +func findSSH() (string, error) { + return "", errors.New("Not implemented") +} + +func execSSH(ssh string, argv []string) error { + return errors.New("Not implemented") +} diff --git a/cmd/tailscale/cli/ssh_exec_windows.go b/cmd/tailscale/cli/ssh_exec_windows.go index 30ab70d04..e249afe66 100644 --- a/cmd/tailscale/cli/ssh_exec_windows.go +++ b/cmd/tailscale/cli/ssh_exec_windows.go @@ -1,37 +1,37 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package cli - -import ( - "errors" - "os" - "os/exec" - "path/filepath" -) - -func findSSH() (string, error) { - // use C:\Windows\System32\OpenSSH\ssh.exe since unexpected behavior - // occurred with ssh.exe provided by msys2/cygwin and other environments. - if systemRoot := os.Getenv("SystemRoot"); systemRoot != "" { - exe := filepath.Join(systemRoot, "System32", "OpenSSH", "ssh.exe") - if st, err := os.Stat(exe); err == nil && !st.IsDir() { - return exe, nil - } - } - return exec.LookPath("ssh") -} - -func execSSH(ssh string, argv []string) error { - // Don't use syscall.Exec on Windows, it's not fully implemented. - cmd := exec.Command(ssh, argv[1:]...) - cmd.Stdin = os.Stdin - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - var ee *exec.ExitError - err := cmd.Run() - if errors.As(err, &ee) { - os.Exit(ee.ExitCode()) - } - return err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package cli + +import ( + "errors" + "os" + "os/exec" + "path/filepath" +) + +func findSSH() (string, error) { + // use C:\Windows\System32\OpenSSH\ssh.exe since unexpected behavior + // occurred with ssh.exe provided by msys2/cygwin and other environments. + if systemRoot := os.Getenv("SystemRoot"); systemRoot != "" { + exe := filepath.Join(systemRoot, "System32", "OpenSSH", "ssh.exe") + if st, err := os.Stat(exe); err == nil && !st.IsDir() { + return exe, nil + } + } + return exec.LookPath("ssh") +} + +func execSSH(ssh string, argv []string) error { + // Don't use syscall.Exec on Windows, it's not fully implemented. + cmd := exec.Command(ssh, argv[1:]...) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + var ee *exec.ExitError + err := cmd.Run() + if errors.As(err, &ee) { + os.Exit(ee.ExitCode()) + } + return err +} diff --git a/cmd/tailscale/cli/ssh_unix.go b/cmd/tailscale/cli/ssh_unix.go index 07423b69f..71c0caaa6 100644 --- a/cmd/tailscale/cli/ssh_unix.go +++ b/cmd/tailscale/cli/ssh_unix.go @@ -1,49 +1,49 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !wasm && !windows && !plan9 - -package cli - -import ( - "bytes" - "os" - "path/filepath" - "runtime" - "strconv" - - "golang.org/x/sys/unix" -) - -func init() { - getSSHClientEnvVar = func() string { - if os.Getenv("SUDO_USER") == "" { - // No sudo, just check the env. - return os.Getenv("SSH_CLIENT") - } - if runtime.GOOS != "linux" { - // TODO(maisem): implement this for other platforms. It's not clear - // if there is a way to get the environment for a given process on - // darwin and bsd. - return "" - } - // SID is the session ID of the user's login session. - // It is also the process ID of the original shell that the user logged in with. - // We only need to check the environment of that process. - sid, err := unix.Getsid(os.Getpid()) - if err != nil { - return "" - } - b, err := os.ReadFile(filepath.Join("/proc", strconv.Itoa(sid), "environ")) - if err != nil { - return "" - } - prefix := []byte("SSH_CLIENT=") - for _, env := range bytes.Split(b, []byte{0}) { - if bytes.HasPrefix(env, prefix) { - return string(env[len(prefix):]) - } - } - return "" - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !wasm && !windows && !plan9 + +package cli + +import ( + "bytes" + "os" + "path/filepath" + "runtime" + "strconv" + + "golang.org/x/sys/unix" +) + +func init() { + getSSHClientEnvVar = func() string { + if os.Getenv("SUDO_USER") == "" { + // No sudo, just check the env. + return os.Getenv("SSH_CLIENT") + } + if runtime.GOOS != "linux" { + // TODO(maisem): implement this for other platforms. It's not clear + // if there is a way to get the environment for a given process on + // darwin and bsd. + return "" + } + // SID is the session ID of the user's login session. + // It is also the process ID of the original shell that the user logged in with. + // We only need to check the environment of that process. + sid, err := unix.Getsid(os.Getpid()) + if err != nil { + return "" + } + b, err := os.ReadFile(filepath.Join("/proc", strconv.Itoa(sid), "environ")) + if err != nil { + return "" + } + prefix := []byte("SSH_CLIENT=") + for _, env := range bytes.Split(b, []byte{0}) { + if bytes.HasPrefix(env, prefix) { + return string(env[len(prefix):]) + } + } + return "" + } +} diff --git a/cmd/tailscale/cli/web_test.go b/cmd/tailscale/cli/web_test.go index f1880597e..f2470b364 100644 --- a/cmd/tailscale/cli/web_test.go +++ b/cmd/tailscale/cli/web_test.go @@ -1,45 +1,45 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package cli - -import ( - "testing" -) - -func TestUrlOfListenAddr(t *testing.T) { - tests := []struct { - name string - in, want string - }{ - { - name: "TestLocalhost", - in: "localhost:8088", - want: "http://localhost:8088", - }, - { - name: "TestNoHost", - in: ":8088", - want: "http://127.0.0.1:8088", - }, - { - name: "TestExplicitHost", - in: "127.0.0.2:8088", - want: "http://127.0.0.2:8088", - }, - { - name: "TestIPv6", - in: "[::1]:8088", - want: "http://[::1]:8088", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - u := urlOfListenAddr(tt.in) - if u != tt.want { - t.Errorf("expected url: %q, got: %q", tt.want, u) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package cli + +import ( + "testing" +) + +func TestUrlOfListenAddr(t *testing.T) { + tests := []struct { + name string + in, want string + }{ + { + name: "TestLocalhost", + in: "localhost:8088", + want: "http://localhost:8088", + }, + { + name: "TestNoHost", + in: ":8088", + want: "http://127.0.0.1:8088", + }, + { + name: "TestExplicitHost", + in: "127.0.0.2:8088", + want: "http://127.0.0.2:8088", + }, + { + name: "TestIPv6", + in: "[::1]:8088", + want: "http://[::1]:8088", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u := urlOfListenAddr(tt.in) + if u != tt.want { + t.Errorf("expected url: %q, got: %q", tt.want, u) + } + }) + } +} diff --git a/cmd/tailscale/generate.go b/cmd/tailscale/generate.go index fa38b3704..5c2e9be91 100644 --- a/cmd/tailscale/generate.go +++ b/cmd/tailscale/generate.go @@ -1,8 +1,8 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -//go:generate go run tailscale.com/cmd/mkmanifest amd64 windows-manifest.xml manifest_windows_amd64.syso -//go:generate go run tailscale.com/cmd/mkmanifest 386 windows-manifest.xml manifest_windows_386.syso -//go:generate go run tailscale.com/cmd/mkmanifest arm64 windows-manifest.xml manifest_windows_arm64.syso +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +//go:generate go run tailscale.com/cmd/mkmanifest amd64 windows-manifest.xml manifest_windows_amd64.syso +//go:generate go run tailscale.com/cmd/mkmanifest 386 windows-manifest.xml manifest_windows_386.syso +//go:generate go run tailscale.com/cmd/mkmanifest arm64 windows-manifest.xml manifest_windows_arm64.syso diff --git a/cmd/tailscale/tailscale.go b/cmd/tailscale/tailscale.go index 1848d6508..f6adb6c19 100644 --- a/cmd/tailscale/tailscale.go +++ b/cmd/tailscale/tailscale.go @@ -1,26 +1,26 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The tailscale command is the Tailscale command-line client. It interacts -// with the tailscaled node agent. -package main // import "tailscale.com/cmd/tailscale" - -import ( - "fmt" - "os" - "path/filepath" - "strings" - - "tailscale.com/cmd/tailscale/cli" -) - -func main() { - args := os.Args[1:] - if name, _ := os.Executable(); strings.HasSuffix(filepath.Base(name), ".cgi") { - args = []string{"web", "-cgi"} - } - if err := cli.Run(args); err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The tailscale command is the Tailscale command-line client. It interacts +// with the tailscaled node agent. +package main // import "tailscale.com/cmd/tailscale" + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "tailscale.com/cmd/tailscale/cli" +) + +func main() { + args := os.Args[1:] + if name, _ := os.Executable(); strings.HasSuffix(filepath.Base(name), ".cgi") { + args = []string{"web", "-cgi"} + } + if err := cli.Run(args); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} diff --git a/cmd/tailscale/windows-manifest.xml b/cmd/tailscale/windows-manifest.xml index 5eaa54fa5..6c5f46058 100644 --- a/cmd/tailscale/windows-manifest.xml +++ b/cmd/tailscale/windows-manifest.xml @@ -1,13 +1,13 @@ - - - - - - - - - - - - - + + + + + + + + + + + + + diff --git a/cmd/tailscaled/childproc/childproc.go b/cmd/tailscaled/childproc/childproc.go index 068015c59..cc83a06c6 100644 --- a/cmd/tailscaled/childproc/childproc.go +++ b/cmd/tailscaled/childproc/childproc.go @@ -1,19 +1,19 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package childproc allows other packages to register "tailscaled be-child" -// child process hook code. This avoids duplicating build tags in the -// tailscaled package. Instead, the code that needs to fork/exec the self -// executable (when it's tailscaled) can instead register the code -// they want to run. -package childproc - -var Code = map[string]func([]string) error{} - -// Add registers code f to run as 'tailscaled be-child [args]'. -func Add(typ string, f func(args []string) error) { - if _, dup := Code[typ]; dup { - panic("dup hook " + typ) - } - Code[typ] = f -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package childproc allows other packages to register "tailscaled be-child" +// child process hook code. This avoids duplicating build tags in the +// tailscaled package. Instead, the code that needs to fork/exec the self +// executable (when it's tailscaled) can instead register the code +// they want to run. +package childproc + +var Code = map[string]func([]string) error{} + +// Add registers code f to run as 'tailscaled be-child [args]'. +func Add(typ string, f func(args []string) error) { + if _, dup := Code[typ]; dup { + panic("dup hook " + typ) + } + Code[typ] = f +} diff --git a/cmd/tailscaled/generate.go b/cmd/tailscaled/generate.go index fa38b3704..5c2e9be91 100644 --- a/cmd/tailscaled/generate.go +++ b/cmd/tailscaled/generate.go @@ -1,8 +1,8 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -//go:generate go run tailscale.com/cmd/mkmanifest amd64 windows-manifest.xml manifest_windows_amd64.syso -//go:generate go run tailscale.com/cmd/mkmanifest 386 windows-manifest.xml manifest_windows_386.syso -//go:generate go run tailscale.com/cmd/mkmanifest arm64 windows-manifest.xml manifest_windows_arm64.syso +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +//go:generate go run tailscale.com/cmd/mkmanifest amd64 windows-manifest.xml manifest_windows_amd64.syso +//go:generate go run tailscale.com/cmd/mkmanifest 386 windows-manifest.xml manifest_windows_386.syso +//go:generate go run tailscale.com/cmd/mkmanifest arm64 windows-manifest.xml manifest_windows_arm64.syso diff --git a/cmd/tailscaled/install_darwin.go b/cmd/tailscaled/install_darwin.go index 9013b39ba..05e5eaed8 100644 --- a/cmd/tailscaled/install_darwin.go +++ b/cmd/tailscaled/install_darwin.go @@ -1,199 +1,199 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build go1.19 - -package main - -import ( - "errors" - "fmt" - "io" - "io/fs" - "os" - "os/exec" - "path/filepath" -) - -func init() { - installSystemDaemon = installSystemDaemonDarwin - uninstallSystemDaemon = uninstallSystemDaemonDarwin -} - -// darwinLaunchdPlist is the launchd.plist that's written to -// /Library/LaunchDaemons/com.tailscale.tailscaled.plist or (in the -// future) a user-specific location. -// -// See man launchd.plist. -const darwinLaunchdPlist = ` - - - - - - Label - com.tailscale.tailscaled - - ProgramArguments - - /usr/local/bin/tailscaled - - - RunAtLoad - - - - -` - -const sysPlist = "/Library/LaunchDaemons/com.tailscale.tailscaled.plist" -const targetBin = "/usr/local/bin/tailscaled" -const service = "com.tailscale.tailscaled" - -func uninstallSystemDaemonDarwin(args []string) (ret error) { - if len(args) > 0 { - return errors.New("uninstall subcommand takes no arguments") - } - - plist, err := exec.Command("launchctl", "list", "com.tailscale.tailscaled").Output() - _ = plist // parse it? https://github.com/DHowett/go-plist if we need something. - running := err == nil - - if running { - out, err := exec.Command("launchctl", "stop", "com.tailscale.tailscaled").CombinedOutput() - if err != nil { - fmt.Printf("launchctl stop com.tailscale.tailscaled: %v, %s\n", err, out) - ret = err - } - out, err = exec.Command("launchctl", "unload", sysPlist).CombinedOutput() - if err != nil { - fmt.Printf("launchctl unload %s: %v, %s\n", sysPlist, err, out) - if ret == nil { - ret = err - } - } - } - - if err := os.Remove(sysPlist); err != nil { - if os.IsNotExist(err) { - err = nil - } - if ret == nil { - ret = err - } - } - - // Do not delete targetBin if it's a symlink, which happens if it was installed via - // Homebrew. - if isSymlink(targetBin) { - return ret - } - - if err := os.Remove(targetBin); err != nil { - if os.IsNotExist(err) { - err = nil - } - if ret == nil { - ret = err - } - } - return ret -} - -func installSystemDaemonDarwin(args []string) (err error) { - if len(args) > 0 { - return errors.New("install subcommand takes no arguments") - } - defer func() { - if err != nil && os.Getuid() != 0 { - err = fmt.Errorf("%w; try running tailscaled with sudo", err) - } - }() - - // Best effort: - uninstallSystemDaemonDarwin(nil) - - exe, err := os.Executable() - if err != nil { - return fmt.Errorf("failed to find our own executable path: %w", err) - } - - same, err := sameFile(exe, targetBin) - if err != nil { - return err - } - - // Do not overwrite targetBin with the binary file if it it's already - // pointing to it. This is primarily to handle Homebrew that writes - // /usr/local/bin/tailscaled is a symlink to the actual binary. - if !same { - if err := copyBinary(exe, targetBin); err != nil { - return err - } - } - if err := os.WriteFile(sysPlist, []byte(darwinLaunchdPlist), 0700); err != nil { - return err - } - - if out, err := exec.Command("launchctl", "load", sysPlist).CombinedOutput(); err != nil { - return fmt.Errorf("error running launchctl load %s: %v, %s", sysPlist, err, out) - } - - if out, err := exec.Command("launchctl", "start", service).CombinedOutput(); err != nil { - return fmt.Errorf("error running launchctl start %s: %v, %s", service, err, out) - } - - return nil -} - -// copyBinary copies binary file `src` into `dst`. -func copyBinary(src, dst string) error { - if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil { - return err - } - tmpBin := dst + ".tmp" - f, err := os.Create(tmpBin) - if err != nil { - return err - } - srcf, err := os.Open(src) - if err != nil { - f.Close() - return err - } - _, err = io.Copy(f, srcf) - srcf.Close() - if err != nil { - f.Close() - return err - } - if err := f.Close(); err != nil { - return err - } - if err := os.Chmod(tmpBin, 0755); err != nil { - return err - } - if err := os.Rename(tmpBin, dst); err != nil { - return err - } - - return nil -} - -func isSymlink(path string) bool { - fi, err := os.Lstat(path) - return err == nil && (fi.Mode()&os.ModeSymlink == os.ModeSymlink) -} - -// sameFile returns true if both file paths exist and resolve to the same file. -func sameFile(path1, path2 string) (bool, error) { - dst1, err := filepath.EvalSymlinks(path1) - if err != nil && !errors.Is(err, fs.ErrNotExist) { - return false, fmt.Errorf("EvalSymlinks(%s): %w", path1, err) - } - dst2, err := filepath.EvalSymlinks(path2) - if err != nil && !errors.Is(err, fs.ErrNotExist) { - return false, fmt.Errorf("EvalSymlinks(%s): %w", path2, err) - } - return dst1 == dst2, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build go1.19 + +package main + +import ( + "errors" + "fmt" + "io" + "io/fs" + "os" + "os/exec" + "path/filepath" +) + +func init() { + installSystemDaemon = installSystemDaemonDarwin + uninstallSystemDaemon = uninstallSystemDaemonDarwin +} + +// darwinLaunchdPlist is the launchd.plist that's written to +// /Library/LaunchDaemons/com.tailscale.tailscaled.plist or (in the +// future) a user-specific location. +// +// See man launchd.plist. +const darwinLaunchdPlist = ` + + + + + + Label + com.tailscale.tailscaled + + ProgramArguments + + /usr/local/bin/tailscaled + + + RunAtLoad + + + + +` + +const sysPlist = "/Library/LaunchDaemons/com.tailscale.tailscaled.plist" +const targetBin = "/usr/local/bin/tailscaled" +const service = "com.tailscale.tailscaled" + +func uninstallSystemDaemonDarwin(args []string) (ret error) { + if len(args) > 0 { + return errors.New("uninstall subcommand takes no arguments") + } + + plist, err := exec.Command("launchctl", "list", "com.tailscale.tailscaled").Output() + _ = plist // parse it? https://github.com/DHowett/go-plist if we need something. + running := err == nil + + if running { + out, err := exec.Command("launchctl", "stop", "com.tailscale.tailscaled").CombinedOutput() + if err != nil { + fmt.Printf("launchctl stop com.tailscale.tailscaled: %v, %s\n", err, out) + ret = err + } + out, err = exec.Command("launchctl", "unload", sysPlist).CombinedOutput() + if err != nil { + fmt.Printf("launchctl unload %s: %v, %s\n", sysPlist, err, out) + if ret == nil { + ret = err + } + } + } + + if err := os.Remove(sysPlist); err != nil { + if os.IsNotExist(err) { + err = nil + } + if ret == nil { + ret = err + } + } + + // Do not delete targetBin if it's a symlink, which happens if it was installed via + // Homebrew. + if isSymlink(targetBin) { + return ret + } + + if err := os.Remove(targetBin); err != nil { + if os.IsNotExist(err) { + err = nil + } + if ret == nil { + ret = err + } + } + return ret +} + +func installSystemDaemonDarwin(args []string) (err error) { + if len(args) > 0 { + return errors.New("install subcommand takes no arguments") + } + defer func() { + if err != nil && os.Getuid() != 0 { + err = fmt.Errorf("%w; try running tailscaled with sudo", err) + } + }() + + // Best effort: + uninstallSystemDaemonDarwin(nil) + + exe, err := os.Executable() + if err != nil { + return fmt.Errorf("failed to find our own executable path: %w", err) + } + + same, err := sameFile(exe, targetBin) + if err != nil { + return err + } + + // Do not overwrite targetBin with the binary file if it it's already + // pointing to it. This is primarily to handle Homebrew that writes + // /usr/local/bin/tailscaled is a symlink to the actual binary. + if !same { + if err := copyBinary(exe, targetBin); err != nil { + return err + } + } + if err := os.WriteFile(sysPlist, []byte(darwinLaunchdPlist), 0700); err != nil { + return err + } + + if out, err := exec.Command("launchctl", "load", sysPlist).CombinedOutput(); err != nil { + return fmt.Errorf("error running launchctl load %s: %v, %s", sysPlist, err, out) + } + + if out, err := exec.Command("launchctl", "start", service).CombinedOutput(); err != nil { + return fmt.Errorf("error running launchctl start %s: %v, %s", service, err, out) + } + + return nil +} + +// copyBinary copies binary file `src` into `dst`. +func copyBinary(src, dst string) error { + if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil { + return err + } + tmpBin := dst + ".tmp" + f, err := os.Create(tmpBin) + if err != nil { + return err + } + srcf, err := os.Open(src) + if err != nil { + f.Close() + return err + } + _, err = io.Copy(f, srcf) + srcf.Close() + if err != nil { + f.Close() + return err + } + if err := f.Close(); err != nil { + return err + } + if err := os.Chmod(tmpBin, 0755); err != nil { + return err + } + if err := os.Rename(tmpBin, dst); err != nil { + return err + } + + return nil +} + +func isSymlink(path string) bool { + fi, err := os.Lstat(path) + return err == nil && (fi.Mode()&os.ModeSymlink == os.ModeSymlink) +} + +// sameFile returns true if both file paths exist and resolve to the same file. +func sameFile(path1, path2 string) (bool, error) { + dst1, err := filepath.EvalSymlinks(path1) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return false, fmt.Errorf("EvalSymlinks(%s): %w", path1, err) + } + dst2, err := filepath.EvalSymlinks(path2) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return false, fmt.Errorf("EvalSymlinks(%s): %w", path2, err) + } + return dst1 == dst2, nil +} diff --git a/cmd/tailscaled/install_windows.go b/cmd/tailscaled/install_windows.go index 9e39c8ab3..c36418642 100644 --- a/cmd/tailscaled/install_windows.go +++ b/cmd/tailscaled/install_windows.go @@ -1,124 +1,124 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build go1.19 - -package main - -import ( - "context" - "errors" - "fmt" - "os" - "time" - - "golang.org/x/sys/windows" - "golang.org/x/sys/windows/svc" - "golang.org/x/sys/windows/svc/mgr" - "tailscale.com/logtail/backoff" - "tailscale.com/types/logger" - "tailscale.com/util/osshare" -) - -func init() { - installSystemDaemon = installSystemDaemonWindows - uninstallSystemDaemon = uninstallSystemDaemonWindows -} - -func installSystemDaemonWindows(args []string) (err error) { - m, err := mgr.Connect() - if err != nil { - return fmt.Errorf("failed to connect to Windows service manager: %v", err) - } - - service, err := m.OpenService(serviceName) - if err == nil { - service.Close() - return fmt.Errorf("service %q is already installed", serviceName) - } - - // no such service; proceed to install the service. - - exe, err := os.Executable() - if err != nil { - return err - } - - c := mgr.Config{ - ServiceType: windows.SERVICE_WIN32_OWN_PROCESS, - StartType: mgr.StartAutomatic, - ErrorControl: mgr.ErrorNormal, - DisplayName: serviceName, - Description: "Connects this computer to others on the Tailscale network.", - } - - service, err = m.CreateService(serviceName, exe, c) - if err != nil { - return fmt.Errorf("failed to create %q service: %v", serviceName, err) - } - defer service.Close() - - // Exponential backoff is often too aggressive, so use (mostly) - // squares instead. - ra := []mgr.RecoveryAction{ - {mgr.ServiceRestart, 1 * time.Second}, - {mgr.ServiceRestart, 2 * time.Second}, - {mgr.ServiceRestart, 4 * time.Second}, - {mgr.ServiceRestart, 9 * time.Second}, - {mgr.ServiceRestart, 16 * time.Second}, - {mgr.ServiceRestart, 25 * time.Second}, - {mgr.ServiceRestart, 36 * time.Second}, - {mgr.ServiceRestart, 49 * time.Second}, - {mgr.ServiceRestart, 64 * time.Second}, - } - const resetPeriodSecs = 60 - err = service.SetRecoveryActions(ra, resetPeriodSecs) - if err != nil { - return fmt.Errorf("failed to set service recovery actions: %v", err) - } - - return nil -} - -func uninstallSystemDaemonWindows(args []string) (ret error) { - // Remove file sharing from Windows shell (noop in non-windows) - osshare.SetFileSharingEnabled(false, logger.Discard) - - m, err := mgr.Connect() - if err != nil { - return fmt.Errorf("failed to connect to Windows service manager: %v", err) - } - defer m.Disconnect() - - service, err := m.OpenService(serviceName) - if err != nil { - return fmt.Errorf("failed to open %q service: %v", serviceName, err) - } - - st, err := service.Query() - if err != nil { - service.Close() - return fmt.Errorf("failed to query service state: %v", err) - } - if st.State != svc.Stopped { - service.Control(svc.Stop) - } - err = service.Delete() - service.Close() - if err != nil { - return fmt.Errorf("failed to delete service: %v", err) - } - - bo := backoff.NewBackoff("uninstall", logger.Discard, 30*time.Second) - end := time.Now().Add(15 * time.Second) - for time.Until(end) > 0 { - service, err = m.OpenService(serviceName) - if err != nil { - // service is no longer openable; success! - break - } - service.Close() - bo.BackOff(context.Background(), errors.New("service not deleted")) - } - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build go1.19 + +package main + +import ( + "context" + "errors" + "fmt" + "os" + "time" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/mgr" + "tailscale.com/logtail/backoff" + "tailscale.com/types/logger" + "tailscale.com/util/osshare" +) + +func init() { + installSystemDaemon = installSystemDaemonWindows + uninstallSystemDaemon = uninstallSystemDaemonWindows +} + +func installSystemDaemonWindows(args []string) (err error) { + m, err := mgr.Connect() + if err != nil { + return fmt.Errorf("failed to connect to Windows service manager: %v", err) + } + + service, err := m.OpenService(serviceName) + if err == nil { + service.Close() + return fmt.Errorf("service %q is already installed", serviceName) + } + + // no such service; proceed to install the service. + + exe, err := os.Executable() + if err != nil { + return err + } + + c := mgr.Config{ + ServiceType: windows.SERVICE_WIN32_OWN_PROCESS, + StartType: mgr.StartAutomatic, + ErrorControl: mgr.ErrorNormal, + DisplayName: serviceName, + Description: "Connects this computer to others on the Tailscale network.", + } + + service, err = m.CreateService(serviceName, exe, c) + if err != nil { + return fmt.Errorf("failed to create %q service: %v", serviceName, err) + } + defer service.Close() + + // Exponential backoff is often too aggressive, so use (mostly) + // squares instead. + ra := []mgr.RecoveryAction{ + {mgr.ServiceRestart, 1 * time.Second}, + {mgr.ServiceRestart, 2 * time.Second}, + {mgr.ServiceRestart, 4 * time.Second}, + {mgr.ServiceRestart, 9 * time.Second}, + {mgr.ServiceRestart, 16 * time.Second}, + {mgr.ServiceRestart, 25 * time.Second}, + {mgr.ServiceRestart, 36 * time.Second}, + {mgr.ServiceRestart, 49 * time.Second}, + {mgr.ServiceRestart, 64 * time.Second}, + } + const resetPeriodSecs = 60 + err = service.SetRecoveryActions(ra, resetPeriodSecs) + if err != nil { + return fmt.Errorf("failed to set service recovery actions: %v", err) + } + + return nil +} + +func uninstallSystemDaemonWindows(args []string) (ret error) { + // Remove file sharing from Windows shell (noop in non-windows) + osshare.SetFileSharingEnabled(false, logger.Discard) + + m, err := mgr.Connect() + if err != nil { + return fmt.Errorf("failed to connect to Windows service manager: %v", err) + } + defer m.Disconnect() + + service, err := m.OpenService(serviceName) + if err != nil { + return fmt.Errorf("failed to open %q service: %v", serviceName, err) + } + + st, err := service.Query() + if err != nil { + service.Close() + return fmt.Errorf("failed to query service state: %v", err) + } + if st.State != svc.Stopped { + service.Control(svc.Stop) + } + err = service.Delete() + service.Close() + if err != nil { + return fmt.Errorf("failed to delete service: %v", err) + } + + bo := backoff.NewBackoff("uninstall", logger.Discard, 30*time.Second) + end := time.Now().Add(15 * time.Second) + for time.Until(end) > 0 { + service, err = m.OpenService(serviceName) + if err != nil { + // service is no longer openable; success! + break + } + service.Close() + bo.BackOff(context.Background(), errors.New("service not deleted")) + } + return nil +} diff --git a/cmd/tailscaled/proxy.go b/cmd/tailscaled/proxy.go index 109ad029d..a91c62bfa 100644 --- a/cmd/tailscaled/proxy.go +++ b/cmd/tailscaled/proxy.go @@ -1,80 +1,80 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build go1.19 - -// HTTP proxy code - -package main - -import ( - "context" - "io" - "net" - "net/http" - "net/http/httputil" - "strings" -) - -// httpProxyHandler returns an HTTP proxy http.Handler using the -// provided backend dialer. -func httpProxyHandler(dialer func(ctx context.Context, netw, addr string) (net.Conn, error)) http.Handler { - rp := &httputil.ReverseProxy{ - Director: func(r *http.Request) {}, // no change - Transport: &http.Transport{ - DialContext: dialer, - }, - } - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != "CONNECT" { - backURL := r.RequestURI - if strings.HasPrefix(backURL, "/") || backURL == "*" { - http.Error(w, "bogus RequestURI; must be absolute URL or CONNECT", 400) - return - } - rp.ServeHTTP(w, r) - return - } - - // CONNECT support: - - dst := r.RequestURI - c, err := dialer(r.Context(), "tcp", dst) - if err != nil { - w.Header().Set("Tailscale-Connect-Error", err.Error()) - http.Error(w, err.Error(), 500) - return - } - defer c.Close() - - cc, ccbuf, err := w.(http.Hijacker).Hijack() - if err != nil { - http.Error(w, err.Error(), 500) - return - } - defer cc.Close() - - io.WriteString(cc, "HTTP/1.1 200 OK\r\n\r\n") - - var clientSrc io.Reader = ccbuf - if ccbuf.Reader.Buffered() == 0 { - // In the common case (with no - // buffered data), read directly from - // the underlying client connection to - // save some memory, letting the - // bufio.Reader/Writer get GC'ed. - clientSrc = cc - } - - errc := make(chan error, 1) - go func() { - _, err := io.Copy(cc, c) - errc <- err - }() - go func() { - _, err := io.Copy(c, clientSrc) - errc <- err - }() - <-errc - }) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build go1.19 + +// HTTP proxy code + +package main + +import ( + "context" + "io" + "net" + "net/http" + "net/http/httputil" + "strings" +) + +// httpProxyHandler returns an HTTP proxy http.Handler using the +// provided backend dialer. +func httpProxyHandler(dialer func(ctx context.Context, netw, addr string) (net.Conn, error)) http.Handler { + rp := &httputil.ReverseProxy{ + Director: func(r *http.Request) {}, // no change + Transport: &http.Transport{ + DialContext: dialer, + }, + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "CONNECT" { + backURL := r.RequestURI + if strings.HasPrefix(backURL, "/") || backURL == "*" { + http.Error(w, "bogus RequestURI; must be absolute URL or CONNECT", 400) + return + } + rp.ServeHTTP(w, r) + return + } + + // CONNECT support: + + dst := r.RequestURI + c, err := dialer(r.Context(), "tcp", dst) + if err != nil { + w.Header().Set("Tailscale-Connect-Error", err.Error()) + http.Error(w, err.Error(), 500) + return + } + defer c.Close() + + cc, ccbuf, err := w.(http.Hijacker).Hijack() + if err != nil { + http.Error(w, err.Error(), 500) + return + } + defer cc.Close() + + io.WriteString(cc, "HTTP/1.1 200 OK\r\n\r\n") + + var clientSrc io.Reader = ccbuf + if ccbuf.Reader.Buffered() == 0 { + // In the common case (with no + // buffered data), read directly from + // the underlying client connection to + // save some memory, letting the + // bufio.Reader/Writer get GC'ed. + clientSrc = cc + } + + errc := make(chan error, 1) + go func() { + _, err := io.Copy(cc, c) + errc <- err + }() + go func() { + _, err := io.Copy(c, clientSrc) + errc <- err + }() + <-errc + }) +} diff --git a/cmd/tailscaled/sigpipe.go b/cmd/tailscaled/sigpipe.go index 695a88024..2fcdab2a4 100644 --- a/cmd/tailscaled/sigpipe.go +++ b/cmd/tailscaled/sigpipe.go @@ -1,12 +1,12 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build go1.21 && !plan9 - -package main - -import "syscall" - -func init() { - sigPipe = syscall.SIGPIPE -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build go1.21 && !plan9 + +package main + +import "syscall" + +func init() { + sigPipe = syscall.SIGPIPE +} diff --git a/cmd/tailscaled/tailscaled.defaults b/cmd/tailscaled/tailscaled.defaults index 693a6190b..e8384a4f8 100644 --- a/cmd/tailscaled/tailscaled.defaults +++ b/cmd/tailscaled/tailscaled.defaults @@ -1,8 +1,8 @@ -# Set the port to listen on for incoming VPN packets. -# Remote nodes will automatically be informed about the new port number, -# but you might want to configure this in order to set external firewall -# settings. -PORT="41641" - -# Extra flags you might want to pass to tailscaled. -FLAGS="" +# Set the port to listen on for incoming VPN packets. +# Remote nodes will automatically be informed about the new port number, +# but you might want to configure this in order to set external firewall +# settings. +PORT="41641" + +# Extra flags you might want to pass to tailscaled. +FLAGS="" diff --git a/cmd/tailscaled/tailscaled.openrc b/cmd/tailscaled/tailscaled.openrc index 6193247ce..309d70f23 100755 --- a/cmd/tailscaled/tailscaled.openrc +++ b/cmd/tailscaled/tailscaled.openrc @@ -1,25 +1,25 @@ -#!/sbin/openrc-run - -set -a -source /etc/default/tailscaled -set +a - -command="/usr/sbin/tailscaled" -command_args="--state=/var/lib/tailscale/tailscaled.state --port=$PORT --socket=/var/run/tailscale/tailscaled.sock $FLAGS" -command_background=true -pidfile="/run/tailscaled.pid" -start_stop_daemon_args="-1 /var/log/tailscaled.log -2 /var/log/tailscaled.log" - -depend() { - need net -} - -start_pre() { - mkdir -p /var/run/tailscale - mkdir -p /var/lib/tailscale - $command --cleanup -} - -stop_post() { - $command --cleanup -} +#!/sbin/openrc-run + +set -a +source /etc/default/tailscaled +set +a + +command="/usr/sbin/tailscaled" +command_args="--state=/var/lib/tailscale/tailscaled.state --port=$PORT --socket=/var/run/tailscale/tailscaled.sock $FLAGS" +command_background=true +pidfile="/run/tailscaled.pid" +start_stop_daemon_args="-1 /var/log/tailscaled.log -2 /var/log/tailscaled.log" + +depend() { + need net +} + +start_pre() { + mkdir -p /var/run/tailscale + mkdir -p /var/lib/tailscale + $command --cleanup +} + +stop_post() { + $command --cleanup +} diff --git a/cmd/tailscaled/tailscaled_bird.go b/cmd/tailscaled/tailscaled_bird.go index 885f552cb..c76f77bec 100644 --- a/cmd/tailscaled/tailscaled_bird.go +++ b/cmd/tailscaled/tailscaled_bird.go @@ -1,17 +1,17 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build go1.19 && (linux || darwin || freebsd || openbsd) && !ts_omit_bird - -package main - -import ( - "tailscale.com/chirp" - "tailscale.com/wgengine" -) - -func init() { - createBIRDClient = func(ctlSocket string) (wgengine.BIRDClient, error) { - return chirp.New(ctlSocket) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build go1.19 && (linux || darwin || freebsd || openbsd) && !ts_omit_bird + +package main + +import ( + "tailscale.com/chirp" + "tailscale.com/wgengine" +) + +func init() { + createBIRDClient = func(ctlSocket string) (wgengine.BIRDClient, error) { + return chirp.New(ctlSocket) + } +} diff --git a/cmd/tailscaled/tailscaled_notwindows.go b/cmd/tailscaled/tailscaled_notwindows.go index b0a7c1598..d5361cf28 100644 --- a/cmd/tailscaled/tailscaled_notwindows.go +++ b/cmd/tailscaled/tailscaled_notwindows.go @@ -1,14 +1,14 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows && go1.19 - -package main // import "tailscale.com/cmd/tailscaled" - -import "tailscale.com/logpolicy" - -func isWindowsService() bool { return false } - -func runWindowsService(pol *logpolicy.Policy) error { panic("unreachable") } - -func beWindowsSubprocess() bool { return false } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows && go1.19 + +package main // import "tailscale.com/cmd/tailscaled" + +import "tailscale.com/logpolicy" + +func isWindowsService() bool { return false } + +func runWindowsService(pol *logpolicy.Policy) error { panic("unreachable") } + +func beWindowsSubprocess() bool { return false } diff --git a/cmd/tailscaled/windows-manifest.xml b/cmd/tailscaled/windows-manifest.xml index 5eaa54fa5..6c5f46058 100644 --- a/cmd/tailscaled/windows-manifest.xml +++ b/cmd/tailscaled/windows-manifest.xml @@ -1,13 +1,13 @@ - - - - - - - - - - - - - + + + + + + + + + + + + + diff --git a/cmd/tailscaled/with_cli.go b/cmd/tailscaled/with_cli.go index f191fdb45..a8554eb8c 100644 --- a/cmd/tailscaled/with_cli.go +++ b/cmd/tailscaled/with_cli.go @@ -1,23 +1,23 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ts_include_cli - -package main - -import ( - "fmt" - "os" - - "tailscale.com/cmd/tailscale/cli" -) - -func init() { - beCLI = func() { - args := os.Args[1:] - if err := cli.Run(args); err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ts_include_cli + +package main + +import ( + "fmt" + "os" + + "tailscale.com/cmd/tailscale/cli" +) + +func init() { + beCLI = func() { + args := os.Args[1:] + if err := cli.Run(args); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } + } +} diff --git a/cmd/testwrapper/args_test.go b/cmd/testwrapper/args_test.go index f7f30a7eb..10063d7bc 100644 --- a/cmd/testwrapper/args_test.go +++ b/cmd/testwrapper/args_test.go @@ -1,97 +1,97 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -import ( - "slices" - "testing" -) - -func TestSplitArgs(t *testing.T) { - tests := []struct { - name string - in []string - pre, pkgs, post []string - }{ - { - name: "empty", - }, - { - name: "all", - in: []string{"-v", "pkg1", "pkg2", "-run", "TestFoo", "-timeout=20s"}, - pre: []string{"-v"}, - pkgs: []string{"pkg1", "pkg2"}, - post: []string{"-run", "TestFoo", "-timeout=20s"}, - }, - { - name: "only_pkgs", - in: []string{"./..."}, - pkgs: []string{"./..."}, - }, - { - name: "pkgs_and_post", - in: []string{"pkg1", "-run", "TestFoo"}, - pkgs: []string{"pkg1"}, - post: []string{"-run", "TestFoo"}, - }, - { - name: "pkgs_and_post", - in: []string{"-v", "pkg2"}, - pre: []string{"-v"}, - pkgs: []string{"pkg2"}, - }, - { - name: "only_args", - in: []string{"-v", "-run=TestFoo"}, - pre: []string{"-run", "TestFoo", "-v"}, // sorted - }, - { - name: "space_in_pre_arg", - in: []string{"-run", "TestFoo", "./cmd/testwrapper"}, - pre: []string{"-run", "TestFoo"}, - pkgs: []string{"./cmd/testwrapper"}, - }, - { - name: "space_in_arg", - in: []string{"-exec", "sudo -E", "./cmd/testwrapper"}, - pre: []string{"-exec", "sudo -E"}, - pkgs: []string{"./cmd/testwrapper"}, - }, - { - name: "test-arg", - in: []string{"-exec", "sudo -E", "./cmd/testwrapper", "--", "--some-flag"}, - pre: []string{"-exec", "sudo -E"}, - pkgs: []string{"./cmd/testwrapper"}, - post: []string{"--", "--some-flag"}, - }, - { - name: "dupe-args", - in: []string{"-v", "-v", "-race", "-race", "./cmd/testwrapper", "--", "--some-flag"}, - pre: []string{"-race", "-v"}, - pkgs: []string{"./cmd/testwrapper"}, - post: []string{"--", "--some-flag"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - pre, pkgs, post, err := splitArgs(tt.in) - if err != nil { - t.Fatal(err) - } - if !slices.Equal(pre, tt.pre) { - t.Errorf("pre = %q; want %q", pre, tt.pre) - } - if !slices.Equal(pkgs, tt.pkgs) { - t.Errorf("pattern = %q; want %q", pkgs, tt.pkgs) - } - if !slices.Equal(post, tt.post) { - t.Errorf("post = %q; want %q", post, tt.post) - } - if t.Failed() { - t.Logf("SplitArgs(%q) = %q %q %q", tt.in, pre, pkgs, post) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "slices" + "testing" +) + +func TestSplitArgs(t *testing.T) { + tests := []struct { + name string + in []string + pre, pkgs, post []string + }{ + { + name: "empty", + }, + { + name: "all", + in: []string{"-v", "pkg1", "pkg2", "-run", "TestFoo", "-timeout=20s"}, + pre: []string{"-v"}, + pkgs: []string{"pkg1", "pkg2"}, + post: []string{"-run", "TestFoo", "-timeout=20s"}, + }, + { + name: "only_pkgs", + in: []string{"./..."}, + pkgs: []string{"./..."}, + }, + { + name: "pkgs_and_post", + in: []string{"pkg1", "-run", "TestFoo"}, + pkgs: []string{"pkg1"}, + post: []string{"-run", "TestFoo"}, + }, + { + name: "pkgs_and_post", + in: []string{"-v", "pkg2"}, + pre: []string{"-v"}, + pkgs: []string{"pkg2"}, + }, + { + name: "only_args", + in: []string{"-v", "-run=TestFoo"}, + pre: []string{"-run", "TestFoo", "-v"}, // sorted + }, + { + name: "space_in_pre_arg", + in: []string{"-run", "TestFoo", "./cmd/testwrapper"}, + pre: []string{"-run", "TestFoo"}, + pkgs: []string{"./cmd/testwrapper"}, + }, + { + name: "space_in_arg", + in: []string{"-exec", "sudo -E", "./cmd/testwrapper"}, + pre: []string{"-exec", "sudo -E"}, + pkgs: []string{"./cmd/testwrapper"}, + }, + { + name: "test-arg", + in: []string{"-exec", "sudo -E", "./cmd/testwrapper", "--", "--some-flag"}, + pre: []string{"-exec", "sudo -E"}, + pkgs: []string{"./cmd/testwrapper"}, + post: []string{"--", "--some-flag"}, + }, + { + name: "dupe-args", + in: []string{"-v", "-v", "-race", "-race", "./cmd/testwrapper", "--", "--some-flag"}, + pre: []string{"-race", "-v"}, + pkgs: []string{"./cmd/testwrapper"}, + post: []string{"--", "--some-flag"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pre, pkgs, post, err := splitArgs(tt.in) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(pre, tt.pre) { + t.Errorf("pre = %q; want %q", pre, tt.pre) + } + if !slices.Equal(pkgs, tt.pkgs) { + t.Errorf("pattern = %q; want %q", pkgs, tt.pkgs) + } + if !slices.Equal(post, tt.post) { + t.Errorf("post = %q; want %q", post, tt.post) + } + if t.Failed() { + t.Logf("SplitArgs(%q) = %q %q %q", tt.in, pre, pkgs, post) + } + }) + } +} diff --git a/cmd/testwrapper/flakytest/flakytest.go b/cmd/testwrapper/flakytest/flakytest.go index e5e21dd21..494ed080b 100644 --- a/cmd/testwrapper/flakytest/flakytest.go +++ b/cmd/testwrapper/flakytest/flakytest.go @@ -1,44 +1,44 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package flakytest contains test helpers for marking a test as flaky. For -// tests run using cmd/testwrapper, a failed flaky test will cause tests to be -// re-run a few time until they succeed or exceed our iteration limit. -package flakytest - -import ( - "fmt" - "os" - "regexp" - "testing" -) - -// FlakyTestLogMessage is a sentinel value that is printed to stderr when a -// flaky test is marked. This is used by cmd/testwrapper to detect flaky tests -// and retry them. -const FlakyTestLogMessage = "flakytest: this is a known flaky test" - -// FlakeAttemptEnv is an environment variable that is set by cmd/testwrapper -// when a flaky test is being (re)tried. It contains the attempt number, -// starting at 1. -const FlakeAttemptEnv = "TS_TESTWRAPPER_ATTEMPT" - -var issueRegexp = regexp.MustCompile(`\Ahttps://github\.com/tailscale/[a-zA-Z0-9_.-]+/issues/\d+\z`) - -// Mark sets the current test as a flaky test, such that if it fails, it will -// be retried a few times on failure. issue must be a GitHub issue that tracks -// the status of the flaky test being marked, of the format: -// -// https://github.com/tailscale/myRepo-H3re/issues/12345 -func Mark(t testing.TB, issue string) { - if !issueRegexp.MatchString(issue) { - t.Fatalf("bad issue format: %q", issue) - } - if _, ok := os.LookupEnv(FlakeAttemptEnv); ok { - // We're being run under cmd/testwrapper so send our sentinel message - // to stderr. (We avoid doing this when the env is absent to avoid - // spamming people running tests without the wrapper) - fmt.Fprintf(os.Stderr, "%s: %s\n", FlakyTestLogMessage, issue) - } - t.Logf("flakytest: issue tracking this flaky test: %s", issue) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package flakytest contains test helpers for marking a test as flaky. For +// tests run using cmd/testwrapper, a failed flaky test will cause tests to be +// re-run a few time until they succeed or exceed our iteration limit. +package flakytest + +import ( + "fmt" + "os" + "regexp" + "testing" +) + +// FlakyTestLogMessage is a sentinel value that is printed to stderr when a +// flaky test is marked. This is used by cmd/testwrapper to detect flaky tests +// and retry them. +const FlakyTestLogMessage = "flakytest: this is a known flaky test" + +// FlakeAttemptEnv is an environment variable that is set by cmd/testwrapper +// when a flaky test is being (re)tried. It contains the attempt number, +// starting at 1. +const FlakeAttemptEnv = "TS_TESTWRAPPER_ATTEMPT" + +var issueRegexp = regexp.MustCompile(`\Ahttps://github\.com/tailscale/[a-zA-Z0-9_.-]+/issues/\d+\z`) + +// Mark sets the current test as a flaky test, such that if it fails, it will +// be retried a few times on failure. issue must be a GitHub issue that tracks +// the status of the flaky test being marked, of the format: +// +// https://github.com/tailscale/myRepo-H3re/issues/12345 +func Mark(t testing.TB, issue string) { + if !issueRegexp.MatchString(issue) { + t.Fatalf("bad issue format: %q", issue) + } + if _, ok := os.LookupEnv(FlakeAttemptEnv); ok { + // We're being run under cmd/testwrapper so send our sentinel message + // to stderr. (We avoid doing this when the env is absent to avoid + // spamming people running tests without the wrapper) + fmt.Fprintf(os.Stderr, "%s: %s\n", FlakyTestLogMessage, issue) + } + t.Logf("flakytest: issue tracking this flaky test: %s", issue) +} diff --git a/cmd/testwrapper/flakytest/flakytest_test.go b/cmd/testwrapper/flakytest/flakytest_test.go index 551352f6a..85e77a939 100644 --- a/cmd/testwrapper/flakytest/flakytest_test.go +++ b/cmd/testwrapper/flakytest/flakytest_test.go @@ -1,43 +1,43 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package flakytest - -import ( - "os" - "testing" -) - -func TestIssueFormat(t *testing.T) { - testCases := []struct { - issue string - want bool - }{ - {"https://github.com/tailscale/cOrp/issues/1234", true}, - {"https://github.com/otherproject/corp/issues/1234", false}, - {"https://github.com/tailscale/corp/issues/", false}, - } - for _, testCase := range testCases { - if issueRegexp.MatchString(testCase.issue) != testCase.want { - ss := "" - if !testCase.want { - ss = " not" - } - t.Errorf("expected issueRegexp to%s match %q", ss, testCase.issue) - } - } -} - -// TestFlakeRun is a test that fails when run in the testwrapper -// for the first time, but succeeds on the second run. -// It's used to test whether the testwrapper retries flaky tests. -func TestFlakeRun(t *testing.T) { - Mark(t, "https://github.com/tailscale/tailscale/issues/0") // random issue - e := os.Getenv(FlakeAttemptEnv) - if e == "" { - t.Skip("not running in testwrapper") - } - if e == "1" { - t.Fatal("First run in testwrapper, failing so that test is retried. This is expected.") - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package flakytest + +import ( + "os" + "testing" +) + +func TestIssueFormat(t *testing.T) { + testCases := []struct { + issue string + want bool + }{ + {"https://github.com/tailscale/cOrp/issues/1234", true}, + {"https://github.com/otherproject/corp/issues/1234", false}, + {"https://github.com/tailscale/corp/issues/", false}, + } + for _, testCase := range testCases { + if issueRegexp.MatchString(testCase.issue) != testCase.want { + ss := "" + if !testCase.want { + ss = " not" + } + t.Errorf("expected issueRegexp to%s match %q", ss, testCase.issue) + } + } +} + +// TestFlakeRun is a test that fails when run in the testwrapper +// for the first time, but succeeds on the second run. +// It's used to test whether the testwrapper retries flaky tests. +func TestFlakeRun(t *testing.T) { + Mark(t, "https://github.com/tailscale/tailscale/issues/0") // random issue + e := os.Getenv(FlakeAttemptEnv) + if e == "" { + t.Skip("not running in testwrapper") + } + if e == "1" { + t.Fatal("First run in testwrapper, failing so that test is retried. This is expected.") + } +} diff --git a/cmd/tsconnect/.gitignore b/cmd/tsconnect/.gitignore index b791f8e64..13615d121 100644 --- a/cmd/tsconnect/.gitignore +++ b/cmd/tsconnect/.gitignore @@ -1,3 +1,3 @@ -node_modules/ -/dist -/pkg +node_modules/ +/dist +/pkg diff --git a/cmd/tsconnect/README.md b/cmd/tsconnect/README.md index f518f932e..536cd7bbf 100644 --- a/cmd/tsconnect/README.md +++ b/cmd/tsconnect/README.md @@ -1,49 +1,49 @@ -# tsconnect - -The tsconnect command builds and serves the static site that is generated for -the Tailscale Connect JS/WASM client. - -## Development - -To start the development server: - -``` -./tool/go run ./cmd/tsconnect dev -``` - -The site is served at http://localhost:9090/. JavaScript, CSS and Go `wasm` package changes can be picked up with a browser reload. Server-side Go changes require the server to be stopped and restarted. In development mode the state the Tailscale client state is stored in `sessionStorage` and will thus survive page reloads (but not the tab being closed). - -## Deployment - -To build the static assets necessary for serving, run: - -``` -./tool/go run ./cmd/tsconnect build -``` - -To serve them, run: - -``` -./tool/go run ./cmd/tsconnect serve -``` - -By default the build output is placed in the `dist/` directory and embedded in the binary, but this can be controlled by the `-distdir` flag. The `-addr` flag controls the interface and port that the serve listens on. - -# Library / NPM Package - -The client is also available as [an NPM package](https://www.npmjs.com/package/@tailscale/connect). To build it, run: - -``` -./tool/go run ./cmd/tsconnect build-pkg -``` - -That places the output in the `pkg/` directory, which may then be uploaded to a package registry (or installed from the file path directly). - -To do two-sided development (on both the NPM package and code that uses it), run: - -``` -./tool/go run ./cmd/tsconnect dev-pkg - -``` - -This serves the module at http://localhost:9090/pkg/pkg.js and the generated wasm file at http://localhost:9090/pkg/main.wasm. The two files can be used as drop-in replacements for normal imports of the NPM module. +# tsconnect + +The tsconnect command builds and serves the static site that is generated for +the Tailscale Connect JS/WASM client. + +## Development + +To start the development server: + +``` +./tool/go run ./cmd/tsconnect dev +``` + +The site is served at http://localhost:9090/. JavaScript, CSS and Go `wasm` package changes can be picked up with a browser reload. Server-side Go changes require the server to be stopped and restarted. In development mode the state the Tailscale client state is stored in `sessionStorage` and will thus survive page reloads (but not the tab being closed). + +## Deployment + +To build the static assets necessary for serving, run: + +``` +./tool/go run ./cmd/tsconnect build +``` + +To serve them, run: + +``` +./tool/go run ./cmd/tsconnect serve +``` + +By default the build output is placed in the `dist/` directory and embedded in the binary, but this can be controlled by the `-distdir` flag. The `-addr` flag controls the interface and port that the serve listens on. + +# Library / NPM Package + +The client is also available as [an NPM package](https://www.npmjs.com/package/@tailscale/connect). To build it, run: + +``` +./tool/go run ./cmd/tsconnect build-pkg +``` + +That places the output in the `pkg/` directory, which may then be uploaded to a package registry (or installed from the file path directly). + +To do two-sided development (on both the NPM package and code that uses it), run: + +``` +./tool/go run ./cmd/tsconnect dev-pkg + +``` + +This serves the module at http://localhost:9090/pkg/pkg.js and the generated wasm file at http://localhost:9090/pkg/main.wasm. The two files can be used as drop-in replacements for normal imports of the NPM module. diff --git a/cmd/tsconnect/README.pkg.md b/cmd/tsconnect/README.pkg.md index df5799578..df8d66789 100644 --- a/cmd/tsconnect/README.pkg.md +++ b/cmd/tsconnect/README.pkg.md @@ -1,3 +1,3 @@ -# @tailscale/connect - -NPM package that contains a WebAssembly-based Tailscale client, see [the `cmd/tsconnect` directory in the tailscale repo](https://github.com/tailscale/tailscale/tree/main/cmd/tsconnect#library--npm-package) for more details. +# @tailscale/connect + +NPM package that contains a WebAssembly-based Tailscale client, see [the `cmd/tsconnect` directory in the tailscale repo](https://github.com/tailscale/tailscale/tree/main/cmd/tsconnect#library--npm-package) for more details. diff --git a/cmd/tsconnect/build-pkg.go b/cmd/tsconnect/build-pkg.go index 2b6cc9b1f..047504858 100644 --- a/cmd/tsconnect/build-pkg.go +++ b/cmd/tsconnect/build-pkg.go @@ -1,99 +1,99 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !plan9 - -package main - -import ( - "encoding/json" - "fmt" - "log" - "os" - "path" - - "github.com/tailscale/hujson" - "tailscale.com/util/precompress" - "tailscale.com/version" -) - -func runBuildPkg() { - buildOptions, err := commonPkgSetup(prodMode) - if err != nil { - log.Fatalf("Cannot setup: %v", err) - } - - log.Printf("Linting...\n") - if err := runYarn("lint"); err != nil { - log.Fatalf("Linting failed: %v", err) - } - - if err := cleanDir(*pkgDir); err != nil { - log.Fatalf("Cannot clean %s: %v", *pkgDir, err) - } - - buildOptions.Write = true - buildOptions.MinifyWhitespace = true - buildOptions.MinifyIdentifiers = true - buildOptions.MinifySyntax = true - - runEsbuild(*buildOptions) - - if err := precompressWasm(); err != nil { - log.Fatalf("Could not pre-recompress wasm: %v", err) - } - - log.Printf("Generating types...\n") - if err := runYarn("pkg-types"); err != nil { - log.Fatalf("Type generation failed: %v", err) - } - - if err := updateVersion(); err != nil { - log.Fatalf("Cannot update version: %v", err) - } - - if err := copyReadme(); err != nil { - log.Fatalf("Cannot copy readme: %v", err) - } - - log.Printf("Built package version %s", version.Long()) -} - -func precompressWasm() error { - log.Printf("Pre-compressing main.wasm...\n") - return precompress.Precompress(path.Join(*pkgDir, "main.wasm"), precompress.Options{ - FastCompression: *fastCompression, - }) -} - -func updateVersion() error { - packageJSONBytes, err := os.ReadFile("package.json.tmpl") - if err != nil { - return fmt.Errorf("Could not read package.json: %w", err) - } - - var packageJSON map[string]any - packageJSONBytes, err = hujson.Standardize(packageJSONBytes) - if err != nil { - return fmt.Errorf("Could not standardize template package.json: %w", err) - } - if err := json.Unmarshal(packageJSONBytes, &packageJSON); err != nil { - return fmt.Errorf("Could not unmarshal package.json: %w", err) - } - packageJSON["version"] = version.Long() - - packageJSONBytes, err = json.MarshalIndent(packageJSON, "", " ") - if err != nil { - return fmt.Errorf("Could not marshal package.json: %w", err) - } - - return os.WriteFile(path.Join(*pkgDir, "package.json"), packageJSONBytes, 0644) -} - -func copyReadme() error { - readmeBytes, err := os.ReadFile("README.pkg.md") - if err != nil { - return fmt.Errorf("Could not read README.pkg.md: %w", err) - } - return os.WriteFile(path.Join(*pkgDir, "README.md"), readmeBytes, 0644) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "encoding/json" + "fmt" + "log" + "os" + "path" + + "github.com/tailscale/hujson" + "tailscale.com/util/precompress" + "tailscale.com/version" +) + +func runBuildPkg() { + buildOptions, err := commonPkgSetup(prodMode) + if err != nil { + log.Fatalf("Cannot setup: %v", err) + } + + log.Printf("Linting...\n") + if err := runYarn("lint"); err != nil { + log.Fatalf("Linting failed: %v", err) + } + + if err := cleanDir(*pkgDir); err != nil { + log.Fatalf("Cannot clean %s: %v", *pkgDir, err) + } + + buildOptions.Write = true + buildOptions.MinifyWhitespace = true + buildOptions.MinifyIdentifiers = true + buildOptions.MinifySyntax = true + + runEsbuild(*buildOptions) + + if err := precompressWasm(); err != nil { + log.Fatalf("Could not pre-recompress wasm: %v", err) + } + + log.Printf("Generating types...\n") + if err := runYarn("pkg-types"); err != nil { + log.Fatalf("Type generation failed: %v", err) + } + + if err := updateVersion(); err != nil { + log.Fatalf("Cannot update version: %v", err) + } + + if err := copyReadme(); err != nil { + log.Fatalf("Cannot copy readme: %v", err) + } + + log.Printf("Built package version %s", version.Long()) +} + +func precompressWasm() error { + log.Printf("Pre-compressing main.wasm...\n") + return precompress.Precompress(path.Join(*pkgDir, "main.wasm"), precompress.Options{ + FastCompression: *fastCompression, + }) +} + +func updateVersion() error { + packageJSONBytes, err := os.ReadFile("package.json.tmpl") + if err != nil { + return fmt.Errorf("Could not read package.json: %w", err) + } + + var packageJSON map[string]any + packageJSONBytes, err = hujson.Standardize(packageJSONBytes) + if err != nil { + return fmt.Errorf("Could not standardize template package.json: %w", err) + } + if err := json.Unmarshal(packageJSONBytes, &packageJSON); err != nil { + return fmt.Errorf("Could not unmarshal package.json: %w", err) + } + packageJSON["version"] = version.Long() + + packageJSONBytes, err = json.MarshalIndent(packageJSON, "", " ") + if err != nil { + return fmt.Errorf("Could not marshal package.json: %w", err) + } + + return os.WriteFile(path.Join(*pkgDir, "package.json"), packageJSONBytes, 0644) +} + +func copyReadme() error { + readmeBytes, err := os.ReadFile("README.pkg.md") + if err != nil { + return fmt.Errorf("Could not read README.pkg.md: %w", err) + } + return os.WriteFile(path.Join(*pkgDir, "README.md"), readmeBytes, 0644) +} diff --git a/cmd/tsconnect/dev-pkg.go b/cmd/tsconnect/dev-pkg.go index cb5ebf39e..de534c3b2 100644 --- a/cmd/tsconnect/dev-pkg.go +++ b/cmd/tsconnect/dev-pkg.go @@ -1,18 +1,18 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !plan9 - -package main - -import ( - "log" -) - -func runDevPkg() { - buildOptions, err := commonPkgSetup(devMode) - if err != nil { - log.Fatalf("Cannot setup: %v", err) - } - runEsbuildServe(*buildOptions) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "log" +) + +func runDevPkg() { + buildOptions, err := commonPkgSetup(devMode) + if err != nil { + log.Fatalf("Cannot setup: %v", err) + } + runEsbuildServe(*buildOptions) +} diff --git a/cmd/tsconnect/dev.go b/cmd/tsconnect/dev.go index 161eb3b86..87b10adaf 100644 --- a/cmd/tsconnect/dev.go +++ b/cmd/tsconnect/dev.go @@ -1,18 +1,18 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !plan9 - -package main - -import ( - "log" -) - -func runDev() { - buildOptions, err := commonSetup(devMode) - if err != nil { - log.Fatalf("Cannot setup: %v", err) - } - runEsbuildServe(*buildOptions) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "log" +) + +func runDev() { + buildOptions, err := commonSetup(devMode) + if err != nil { + log.Fatalf("Cannot setup: %v", err) + } + runEsbuildServe(*buildOptions) +} diff --git a/cmd/tsconnect/dist/placeholder b/cmd/tsconnect/dist/placeholder index dddaba4d7..4af99d997 100644 --- a/cmd/tsconnect/dist/placeholder +++ b/cmd/tsconnect/dist/placeholder @@ -1,2 +1,2 @@ -This is here to make sure the dist/ directory exists for the go:embed command -in serve.go. +This is here to make sure the dist/ directory exists for the go:embed command +in serve.go. diff --git a/cmd/tsconnect/index.html b/cmd/tsconnect/index.html index 39aa7571a..3db45fdef 100644 --- a/cmd/tsconnect/index.html +++ b/cmd/tsconnect/index.html @@ -1,20 +1,20 @@ - - - - - - Tailscale Connect - - - - - -
-
-

Tailscale Connect

-
Loading…
-
-
- - + + + + + + Tailscale Connect + + + + + +
+
+

Tailscale Connect

+
Loading…
+
+
+ + diff --git a/cmd/tsconnect/package.json b/cmd/tsconnect/package.json index 8ea726cc6..bf4eb7c09 100644 --- a/cmd/tsconnect/package.json +++ b/cmd/tsconnect/package.json @@ -1,25 +1,25 @@ -{ - "name": "tsconnect", - "version": "0.0.1", - "license": "BSD-3-Clause", - "devDependencies": { - "@types/golang-wasm-exec": "^1.15.0", - "@types/qrcode": "^1.4.2", - "dts-bundle-generator": "^6.12.0", - "preact": "^10.10.0", - "qrcode": "^1.5.0", - "tailwindcss": "^3.1.6", - "typescript": "^4.7.4", - "xterm": "^5.1.0", - "xterm-addon-fit": "^0.7.0", - "xterm-addon-web-links": "^0.8.0" - }, - "scripts": { - "lint": "tsc --noEmit", - "pkg-types": "dts-bundle-generator --inline-declare-global=true --no-banner -o pkg/pkg.d.ts src/pkg/pkg.ts" - }, - "prettier": { - "semi": false, - "printWidth": 80 - } -} +{ + "name": "tsconnect", + "version": "0.0.1", + "license": "BSD-3-Clause", + "devDependencies": { + "@types/golang-wasm-exec": "^1.15.0", + "@types/qrcode": "^1.4.2", + "dts-bundle-generator": "^6.12.0", + "preact": "^10.10.0", + "qrcode": "^1.5.0", + "tailwindcss": "^3.1.6", + "typescript": "^4.7.4", + "xterm": "^5.1.0", + "xterm-addon-fit": "^0.7.0", + "xterm-addon-web-links": "^0.8.0" + }, + "scripts": { + "lint": "tsc --noEmit", + "pkg-types": "dts-bundle-generator --inline-declare-global=true --no-banner -o pkg/pkg.d.ts src/pkg/pkg.ts" + }, + "prettier": { + "semi": false, + "printWidth": 80 + } +} diff --git a/cmd/tsconnect/package.json.tmpl b/cmd/tsconnect/package.json.tmpl index 0263bf481..404b896ea 100644 --- a/cmd/tsconnect/package.json.tmpl +++ b/cmd/tsconnect/package.json.tmpl @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Template for the package.json that is generated by the build-pkg command. -// The version number will be replaced by the current Tailscale client version -// number. -{ - "author": "Tailscale Inc.", - "description": "Tailscale Connect SDK", - "license": "BSD-3-Clause", - "name": "@tailscale/connect", - "type": "module", - "main": "./pkg.js", - "types": "./pkg.d.ts", - "version": "AUTO_GENERATED" -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Template for the package.json that is generated by the build-pkg command. +// The version number will be replaced by the current Tailscale client version +// number. +{ + "author": "Tailscale Inc.", + "description": "Tailscale Connect SDK", + "license": "BSD-3-Clause", + "name": "@tailscale/connect", + "type": "module", + "main": "./pkg.js", + "types": "./pkg.d.ts", + "version": "AUTO_GENERATED" +} diff --git a/cmd/tsconnect/serve.go b/cmd/tsconnect/serve.go index 80844bea7..d780bdd57 100644 --- a/cmd/tsconnect/serve.go +++ b/cmd/tsconnect/serve.go @@ -1,144 +1,144 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !plan9 - -package main - -import ( - "bytes" - "embed" - "encoding/json" - "fmt" - "io" - "io/fs" - "log" - "net/http" - "os" - "path" - "time" - - "tailscale.com/tsweb" - "tailscale.com/util/precompress" -) - -//go:embed index.html -var embeddedFS embed.FS - -//go:embed dist/* -var embeddedDistFS embed.FS - -var serveStartTime = time.Now() - -func runServe() { - mux := http.NewServeMux() - - var distFS fs.FS - if *distDir == "./dist" { - var err error - distFS, err = fs.Sub(embeddedDistFS, "dist") - if err != nil { - log.Fatalf("Could not drop dist/ prefix from embedded FS: %v", err) - } - } else { - distFS = os.DirFS(*distDir) - } - - indexBytes, err := generateServeIndex(distFS) - if err != nil { - log.Fatalf("Could not generate index.html: %v", err) - } - mux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.ServeContent(w, r, "index.html", serveStartTime, bytes.NewReader(indexBytes)) - })) - mux.Handle("/dist/", http.StripPrefix("/dist/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - handleServeDist(w, r, distFS) - }))) - tsweb.Debugger(mux) - - log.Printf("Listening on %s", *addr) - err = http.ListenAndServe(*addr, mux) - if err != nil { - log.Fatal(err) - } -} - -func generateServeIndex(distFS fs.FS) ([]byte, error) { - log.Printf("Generating index.html...\n") - rawIndexBytes, err := embeddedFS.ReadFile("index.html") - if err != nil { - return nil, fmt.Errorf("Could not read index.html: %w", err) - } - - esbuildMetadataFile, err := distFS.Open("esbuild-metadata.json") - if err != nil { - return nil, fmt.Errorf("Could not open esbuild-metadata.json: %w", err) - } - defer esbuildMetadataFile.Close() - esbuildMetadataBytes, err := io.ReadAll(esbuildMetadataFile) - if err != nil { - return nil, fmt.Errorf("Could not read esbuild-metadata.json: %w", err) - } - var esbuildMetadata EsbuildMetadata - if err := json.Unmarshal(esbuildMetadataBytes, &esbuildMetadata); err != nil { - return nil, fmt.Errorf("Could not parse esbuild-metadata.json: %w", err) - } - entryPointsToHashedDistPaths := make(map[string]string) - mainWasmPath := "" - for outputPath, output := range esbuildMetadata.Outputs { - if output.EntryPoint != "" { - entryPointsToHashedDistPaths[output.EntryPoint] = path.Join("dist", outputPath) - } - if path.Ext(outputPath) == ".wasm" { - for input := range output.Inputs { - if input == "src/main.wasm" { - mainWasmPath = path.Join("dist", outputPath) - break - } - } - } - } - - indexBytes := rawIndexBytes - for entryPointPath, defaultDistPath := range entryPointsToDefaultDistPaths { - hashedDistPath := entryPointsToHashedDistPaths[entryPointPath] - if hashedDistPath != "" { - indexBytes = bytes.ReplaceAll(indexBytes, []byte(defaultDistPath), []byte(hashedDistPath)) - } - } - if mainWasmPath != "" { - mainWasmPrefetch := fmt.Sprintf("\n", mainWasmPath) - indexBytes = bytes.ReplaceAll(indexBytes, []byte(""), []byte(mainWasmPrefetch)) - } - - return indexBytes, nil -} - -var entryPointsToDefaultDistPaths = map[string]string{ - "src/app/index.css": "dist/index.css", - "src/app/index.ts": "dist/index.js", -} - -func handleServeDist(w http.ResponseWriter, r *http.Request, distFS fs.FS) { - path := r.URL.Path - f, err := precompress.OpenPrecompressedFile(w, r, path, distFS) - if err != nil { - http.Error(w, err.Error(), http.StatusNotFound) - return - } - defer f.Close() - - // fs.File does not claim to implement Seeker, but in practice it does. - fSeeker, ok := f.(io.ReadSeeker) - if !ok { - http.Error(w, "Not seekable", http.StatusInternalServerError) - return - } - - // Aggressively cache static assets, since we cache-bust our assets with - // hashed filenames. - w.Header().Set("Cache-Control", "public, max-age=31535996") - w.Header().Set("Vary", "Accept-Encoding") - - http.ServeContent(w, r, path, serveStartTime, fSeeker) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "bytes" + "embed" + "encoding/json" + "fmt" + "io" + "io/fs" + "log" + "net/http" + "os" + "path" + "time" + + "tailscale.com/tsweb" + "tailscale.com/util/precompress" +) + +//go:embed index.html +var embeddedFS embed.FS + +//go:embed dist/* +var embeddedDistFS embed.FS + +var serveStartTime = time.Now() + +func runServe() { + mux := http.NewServeMux() + + var distFS fs.FS + if *distDir == "./dist" { + var err error + distFS, err = fs.Sub(embeddedDistFS, "dist") + if err != nil { + log.Fatalf("Could not drop dist/ prefix from embedded FS: %v", err) + } + } else { + distFS = os.DirFS(*distDir) + } + + indexBytes, err := generateServeIndex(distFS) + if err != nil { + log.Fatalf("Could not generate index.html: %v", err) + } + mux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.ServeContent(w, r, "index.html", serveStartTime, bytes.NewReader(indexBytes)) + })) + mux.Handle("/dist/", http.StripPrefix("/dist/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handleServeDist(w, r, distFS) + }))) + tsweb.Debugger(mux) + + log.Printf("Listening on %s", *addr) + err = http.ListenAndServe(*addr, mux) + if err != nil { + log.Fatal(err) + } +} + +func generateServeIndex(distFS fs.FS) ([]byte, error) { + log.Printf("Generating index.html...\n") + rawIndexBytes, err := embeddedFS.ReadFile("index.html") + if err != nil { + return nil, fmt.Errorf("Could not read index.html: %w", err) + } + + esbuildMetadataFile, err := distFS.Open("esbuild-metadata.json") + if err != nil { + return nil, fmt.Errorf("Could not open esbuild-metadata.json: %w", err) + } + defer esbuildMetadataFile.Close() + esbuildMetadataBytes, err := io.ReadAll(esbuildMetadataFile) + if err != nil { + return nil, fmt.Errorf("Could not read esbuild-metadata.json: %w", err) + } + var esbuildMetadata EsbuildMetadata + if err := json.Unmarshal(esbuildMetadataBytes, &esbuildMetadata); err != nil { + return nil, fmt.Errorf("Could not parse esbuild-metadata.json: %w", err) + } + entryPointsToHashedDistPaths := make(map[string]string) + mainWasmPath := "" + for outputPath, output := range esbuildMetadata.Outputs { + if output.EntryPoint != "" { + entryPointsToHashedDistPaths[output.EntryPoint] = path.Join("dist", outputPath) + } + if path.Ext(outputPath) == ".wasm" { + for input := range output.Inputs { + if input == "src/main.wasm" { + mainWasmPath = path.Join("dist", outputPath) + break + } + } + } + } + + indexBytes := rawIndexBytes + for entryPointPath, defaultDistPath := range entryPointsToDefaultDistPaths { + hashedDistPath := entryPointsToHashedDistPaths[entryPointPath] + if hashedDistPath != "" { + indexBytes = bytes.ReplaceAll(indexBytes, []byte(defaultDistPath), []byte(hashedDistPath)) + } + } + if mainWasmPath != "" { + mainWasmPrefetch := fmt.Sprintf("\n", mainWasmPath) + indexBytes = bytes.ReplaceAll(indexBytes, []byte(""), []byte(mainWasmPrefetch)) + } + + return indexBytes, nil +} + +var entryPointsToDefaultDistPaths = map[string]string{ + "src/app/index.css": "dist/index.css", + "src/app/index.ts": "dist/index.js", +} + +func handleServeDist(w http.ResponseWriter, r *http.Request, distFS fs.FS) { + path := r.URL.Path + f, err := precompress.OpenPrecompressedFile(w, r, path, distFS) + if err != nil { + http.Error(w, err.Error(), http.StatusNotFound) + return + } + defer f.Close() + + // fs.File does not claim to implement Seeker, but in practice it does. + fSeeker, ok := f.(io.ReadSeeker) + if !ok { + http.Error(w, "Not seekable", http.StatusInternalServerError) + return + } + + // Aggressively cache static assets, since we cache-bust our assets with + // hashed filenames. + w.Header().Set("Cache-Control", "public, max-age=31535996") + w.Header().Set("Vary", "Accept-Encoding") + + http.ServeContent(w, r, path, serveStartTime, fSeeker) +} diff --git a/cmd/tsconnect/src/app/app.tsx b/cmd/tsconnect/src/app/app.tsx index c0aa7a5e8..ee538eaea 100644 --- a/cmd/tsconnect/src/app/app.tsx +++ b/cmd/tsconnect/src/app/app.tsx @@ -1,147 +1,147 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -import { render, Component } from "preact" -import { URLDisplay } from "./url-display" -import { Header } from "./header" -import { GoPanicDisplay } from "./go-panic-display" -import { SSH } from "./ssh" - -type AppState = { - ipn?: IPN - ipnState: IPNState - netMap?: IPNNetMap - browseToURL?: string - goPanicError?: string -} - -class App extends Component<{}, AppState> { - state: AppState = { ipnState: "NoState" } - #goPanicTimeout?: number - - render() { - const { ipn, ipnState, goPanicError, netMap, browseToURL } = this.state - - let goPanicDisplay - if (goPanicError) { - goPanicDisplay = ( - - ) - } - - let urlDisplay - if (browseToURL) { - urlDisplay = - } - - let machineAuthInstructions - if (ipnState === "NeedsMachineAuth") { - machineAuthInstructions = ( -
- An administrator needs to approve this device. -
- ) - } - - const lockedOut = netMap?.lockedOut - let lockedOutInstructions - if (lockedOut) { - lockedOutInstructions = ( -
-

This instance of Tailscale Connect needs to be signed, due to - {" "}tailnet lock{" "} - being enabled on this domain. -

- -

- Run the following command on a device with a trusted tailnet lock key: -

tailscale lock sign {netMap.self.nodeKey}
-

-
- ) - } - - let ssh - if (ipn && ipnState === "Running" && netMap && !lockedOut) { - ssh = - } - - return ( - <> -
- {goPanicDisplay} -
- {urlDisplay} - {machineAuthInstructions} - {lockedOutInstructions} - {ssh} -
- - ) - } - - runWithIPN(ipn: IPN) { - this.setState({ ipn }, () => { - ipn.run({ - notifyState: this.handleIPNState, - notifyNetMap: this.handleNetMap, - notifyBrowseToURL: this.handleBrowseToURL, - notifyPanicRecover: this.handleGoPanic, - }) - }) - } - - handleIPNState = (state: IPNState) => { - const { ipn } = this.state - this.setState({ ipnState: state }) - if (state === "NeedsLogin") { - ipn?.login() - } else if (["Running", "NeedsMachineAuth"].includes(state)) { - this.setState({ browseToURL: undefined }) - } - } - - handleNetMap = (netMapStr: string) => { - const netMap = JSON.parse(netMapStr) as IPNNetMap - if (DEBUG) { - console.log("Received net map: " + JSON.stringify(netMap, null, 2)) - } - this.setState({ netMap }) - } - - handleBrowseToURL = (url: string) => { - if (this.state.ipnState === "Running") { - // Ignore URL requests if we're already running -- it's most likely an - // SSH check mode trigger and we already linkify the displayed URL - // in the terminal. - return - } - this.setState({ browseToURL: url }) - } - - handleGoPanic = (error: string) => { - if (DEBUG) { - console.error("Go panic", error) - } - this.setState({ goPanicError: error }) - if (this.#goPanicTimeout) { - window.clearTimeout(this.#goPanicTimeout) - } - this.#goPanicTimeout = window.setTimeout(this.clearGoPanic, 10000) - } - - clearGoPanic = () => { - window.clearTimeout(this.#goPanicTimeout) - this.#goPanicTimeout = undefined - this.setState({ goPanicError: undefined }) - } -} - -export function renderApp(): Promise { - return new Promise((resolve) => { - render( - (app ? resolve(app) : undefined)} />, - document.body - ) - }) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +import { render, Component } from "preact" +import { URLDisplay } from "./url-display" +import { Header } from "./header" +import { GoPanicDisplay } from "./go-panic-display" +import { SSH } from "./ssh" + +type AppState = { + ipn?: IPN + ipnState: IPNState + netMap?: IPNNetMap + browseToURL?: string + goPanicError?: string +} + +class App extends Component<{}, AppState> { + state: AppState = { ipnState: "NoState" } + #goPanicTimeout?: number + + render() { + const { ipn, ipnState, goPanicError, netMap, browseToURL } = this.state + + let goPanicDisplay + if (goPanicError) { + goPanicDisplay = ( + + ) + } + + let urlDisplay + if (browseToURL) { + urlDisplay = + } + + let machineAuthInstructions + if (ipnState === "NeedsMachineAuth") { + machineAuthInstructions = ( +
+ An administrator needs to approve this device. +
+ ) + } + + const lockedOut = netMap?.lockedOut + let lockedOutInstructions + if (lockedOut) { + lockedOutInstructions = ( +
+

This instance of Tailscale Connect needs to be signed, due to + {" "}tailnet lock{" "} + being enabled on this domain. +

+ +

+ Run the following command on a device with a trusted tailnet lock key: +

tailscale lock sign {netMap.self.nodeKey}
+

+
+ ) + } + + let ssh + if (ipn && ipnState === "Running" && netMap && !lockedOut) { + ssh = + } + + return ( + <> +
+ {goPanicDisplay} +
+ {urlDisplay} + {machineAuthInstructions} + {lockedOutInstructions} + {ssh} +
+ + ) + } + + runWithIPN(ipn: IPN) { + this.setState({ ipn }, () => { + ipn.run({ + notifyState: this.handleIPNState, + notifyNetMap: this.handleNetMap, + notifyBrowseToURL: this.handleBrowseToURL, + notifyPanicRecover: this.handleGoPanic, + }) + }) + } + + handleIPNState = (state: IPNState) => { + const { ipn } = this.state + this.setState({ ipnState: state }) + if (state === "NeedsLogin") { + ipn?.login() + } else if (["Running", "NeedsMachineAuth"].includes(state)) { + this.setState({ browseToURL: undefined }) + } + } + + handleNetMap = (netMapStr: string) => { + const netMap = JSON.parse(netMapStr) as IPNNetMap + if (DEBUG) { + console.log("Received net map: " + JSON.stringify(netMap, null, 2)) + } + this.setState({ netMap }) + } + + handleBrowseToURL = (url: string) => { + if (this.state.ipnState === "Running") { + // Ignore URL requests if we're already running -- it's most likely an + // SSH check mode trigger and we already linkify the displayed URL + // in the terminal. + return + } + this.setState({ browseToURL: url }) + } + + handleGoPanic = (error: string) => { + if (DEBUG) { + console.error("Go panic", error) + } + this.setState({ goPanicError: error }) + if (this.#goPanicTimeout) { + window.clearTimeout(this.#goPanicTimeout) + } + this.#goPanicTimeout = window.setTimeout(this.clearGoPanic, 10000) + } + + clearGoPanic = () => { + window.clearTimeout(this.#goPanicTimeout) + this.#goPanicTimeout = undefined + this.setState({ goPanicError: undefined }) + } +} + +export function renderApp(): Promise { + return new Promise((resolve) => { + render( + (app ? resolve(app) : undefined)} />, + document.body + ) + }) +} diff --git a/cmd/tsconnect/src/app/go-panic-display.tsx b/cmd/tsconnect/src/app/go-panic-display.tsx index aab35c4d5..5dd7095a2 100644 --- a/cmd/tsconnect/src/app/go-panic-display.tsx +++ b/cmd/tsconnect/src/app/go-panic-display.tsx @@ -1,20 +1,20 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -export function GoPanicDisplay({ - error, - dismiss, -}: { - error: string - dismiss: () => void -}) { - return ( -
- Tailscale has encountered an error. -
Click to reload
-
- ) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +export function GoPanicDisplay({ + error, + dismiss, +}: { + error: string + dismiss: () => void +}) { + return ( +
+ Tailscale has encountered an error. +
Click to reload
+
+ ) +} diff --git a/cmd/tsconnect/src/app/header.tsx b/cmd/tsconnect/src/app/header.tsx index 8449f4563..099ff2f8c 100644 --- a/cmd/tsconnect/src/app/header.tsx +++ b/cmd/tsconnect/src/app/header.tsx @@ -1,37 +1,37 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -export function Header({ state, ipn }: { state: IPNState; ipn?: IPN }) { - const stateText = STATE_LABELS[state] - - let logoutButton - if (state === "Running") { - logoutButton = ( - - ) - } - return ( -
-
-

Tailscale Connect

-
{stateText}
- {logoutButton} -
-
- ) -} - -const STATE_LABELS = { - NoState: "Initializing…", - InUseOtherUser: "In-use by another user", - NeedsLogin: "Needs login", - NeedsMachineAuth: "Needs approval", - Stopped: "Stopped", - Starting: "Starting…", - Running: "Running", -} as const +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +export function Header({ state, ipn }: { state: IPNState; ipn?: IPN }) { + const stateText = STATE_LABELS[state] + + let logoutButton + if (state === "Running") { + logoutButton = ( + + ) + } + return ( +
+
+

Tailscale Connect

+
{stateText}
+ {logoutButton} +
+
+ ) +} + +const STATE_LABELS = { + NoState: "Initializing…", + InUseOtherUser: "In-use by another user", + NeedsLogin: "Needs login", + NeedsMachineAuth: "Needs approval", + Stopped: "Stopped", + Starting: "Starting…", + Running: "Running", +} as const diff --git a/cmd/tsconnect/src/app/index.css b/cmd/tsconnect/src/app/index.css index 848b83d12..751b313d9 100644 --- a/cmd/tsconnect/src/app/index.css +++ b/cmd/tsconnect/src/app/index.css @@ -1,74 +1,74 @@ -/* Copyright (c) Tailscale Inc & AUTHORS */ -/* SPDX-License-Identifier: BSD-3-Clause */ - -@import "xterm/css/xterm.css"; - -@tailwind base; -@tailwind components; -@tailwind utilities; - -.link { - @apply text-blue-600; -} - -.link:hover { - @apply underline; -} - -.button { - @apply font-medium py-1 px-2 rounded-md border border-transparent text-center cursor-pointer; - transition-property: background-color, border-color, color, box-shadow; - transition-duration: 120ms; - box-shadow: 0 1px 1px rgba(0, 0, 0, 0.04); - min-width: 80px; -} -.button:focus { - @apply outline-none ring; -} -.button:disabled { - @apply pointer-events-none select-none; -} - -.input { - @apply appearance-none leading-tight rounded-md bg-white border border-gray-300 hover:border-gray-400 transition-colors px-3; - height: 2.375rem; -} - -.input::placeholder { - @apply text-gray-400; -} - -.input:disabled { - @apply border-gray-200; - @apply bg-gray-50; - @apply cursor-not-allowed; -} - -.input:focus { - @apply outline-none ring border-transparent; -} - -.select { - @apply appearance-none py-2 px-3 leading-tight rounded-md bg-white border border-gray-300; -} - -.select-with-arrow { - @apply relative; -} - -.select-with-arrow .select { - width: 100%; -} - -.select-with-arrow::after { - @apply absolute; - content: ""; - top: 50%; - right: 0.5rem; - transform: translate(-0.3em, -0.15em); - width: 0.6em; - height: 0.4em; - opacity: 0.6; - background-color: currentColor; - clip-path: polygon(100% 0%, 0 0%, 50% 100%); -} +/* Copyright (c) Tailscale Inc & AUTHORS */ +/* SPDX-License-Identifier: BSD-3-Clause */ + +@import "xterm/css/xterm.css"; + +@tailwind base; +@tailwind components; +@tailwind utilities; + +.link { + @apply text-blue-600; +} + +.link:hover { + @apply underline; +} + +.button { + @apply font-medium py-1 px-2 rounded-md border border-transparent text-center cursor-pointer; + transition-property: background-color, border-color, color, box-shadow; + transition-duration: 120ms; + box-shadow: 0 1px 1px rgba(0, 0, 0, 0.04); + min-width: 80px; +} +.button:focus { + @apply outline-none ring; +} +.button:disabled { + @apply pointer-events-none select-none; +} + +.input { + @apply appearance-none leading-tight rounded-md bg-white border border-gray-300 hover:border-gray-400 transition-colors px-3; + height: 2.375rem; +} + +.input::placeholder { + @apply text-gray-400; +} + +.input:disabled { + @apply border-gray-200; + @apply bg-gray-50; + @apply cursor-not-allowed; +} + +.input:focus { + @apply outline-none ring border-transparent; +} + +.select { + @apply appearance-none py-2 px-3 leading-tight rounded-md bg-white border border-gray-300; +} + +.select-with-arrow { + @apply relative; +} + +.select-with-arrow .select { + width: 100%; +} + +.select-with-arrow::after { + @apply absolute; + content: ""; + top: 50%; + right: 0.5rem; + transform: translate(-0.3em, -0.15em); + width: 0.6em; + height: 0.4em; + opacity: 0.6; + background-color: currentColor; + clip-path: polygon(100% 0%, 0 0%, 50% 100%); +} diff --git a/cmd/tsconnect/src/app/index.ts b/cmd/tsconnect/src/app/index.ts index 1432188ae..24ca45439 100644 --- a/cmd/tsconnect/src/app/index.ts +++ b/cmd/tsconnect/src/app/index.ts @@ -1,36 +1,36 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -import "../wasm_exec" -import wasmUrl from "./main.wasm" -import { sessionStateStorage } from "../lib/js-state-store" -import { renderApp } from "./app" - -async function main() { - const app = await renderApp() - const go = new Go() - const wasmInstance = await WebAssembly.instantiateStreaming( - fetch(`./dist/${wasmUrl}`), - go.importObject - ) - // The Go process should never exit, if it does then it's an unhandled panic. - go.run(wasmInstance.instance).then(() => - app.handleGoPanic("Unexpected shutdown") - ) - - const params = new URLSearchParams(window.location.search) - const authKey = params.get("authkey") ?? undefined - - const ipn = newIPN({ - // Persist IPN state in sessionStorage in development, so that we don't need - // to re-authorize every time we reload the page. - stateStorage: DEBUG ? sessionStateStorage : undefined, - // authKey allows for an auth key to be - // specified as a url param which automatically - // authorizes the client for use. - authKey: DEBUG ? authKey : undefined, - }) - app.runWithIPN(ipn) -} - -main() +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +import "../wasm_exec" +import wasmUrl from "./main.wasm" +import { sessionStateStorage } from "../lib/js-state-store" +import { renderApp } from "./app" + +async function main() { + const app = await renderApp() + const go = new Go() + const wasmInstance = await WebAssembly.instantiateStreaming( + fetch(`./dist/${wasmUrl}`), + go.importObject + ) + // The Go process should never exit, if it does then it's an unhandled panic. + go.run(wasmInstance.instance).then(() => + app.handleGoPanic("Unexpected shutdown") + ) + + const params = new URLSearchParams(window.location.search) + const authKey = params.get("authkey") ?? undefined + + const ipn = newIPN({ + // Persist IPN state in sessionStorage in development, so that we don't need + // to re-authorize every time we reload the page. + stateStorage: DEBUG ? sessionStateStorage : undefined, + // authKey allows for an auth key to be + // specified as a url param which automatically + // authorizes the client for use. + authKey: DEBUG ? authKey : undefined, + }) + app.runWithIPN(ipn) +} + +main() diff --git a/cmd/tsconnect/src/app/ssh.tsx b/cmd/tsconnect/src/app/ssh.tsx index 1534fd5db..df81745bd 100644 --- a/cmd/tsconnect/src/app/ssh.tsx +++ b/cmd/tsconnect/src/app/ssh.tsx @@ -1,157 +1,157 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -import { useState, useCallback, useMemo, useEffect, useRef } from "preact/hooks" -import { createPortal } from "preact/compat" -import type { VNode } from "preact" -import { runSSHSession, SSHSessionDef } from "../lib/ssh" - -export function SSH({ netMap, ipn }: { netMap: IPNNetMap; ipn: IPN }) { - const [sshSessionDef, setSSHSessionDef] = useState( - null - ) - const clearSSHSessionDef = useCallback(() => setSSHSessionDef(null), []) - if (sshSessionDef) { - const sshSession = ( - - ) - if (sshSessionDef.newWindow) { - return {sshSession} - } - return sshSession - } - const sshPeers = netMap.peers.filter( - (p) => p.tailscaleSSHEnabled && p.online !== false - ) - - if (sshPeers.length == 0) { - return - } - - return -} - -type SSHFormSessionDef = SSHSessionDef & { newWindow?: boolean } - -function SSHSession({ - def, - ipn, - onDone, -}: { - def: SSHSessionDef - ipn: IPN - onDone: () => void -}) { - const ref = useRef(null) - useEffect(() => { - if (ref.current) { - runSSHSession(ref.current, def, ipn, { - onConnectionProgress: (p) => console.log("Connection progress", p), - onConnected() {}, - onError: (err) => console.error(err), - onDone, - }) - } - }, [ref]) - - return
-} - -function NoSSHPeers() { - return ( -
- None of your machines have{" "} - - Tailscale SSH - - {" "}enabled. Give it a try! -
- ) -} - -function SSHForm({ - sshPeers, - onSubmit, -}: { - sshPeers: IPNNetMapPeerNode[] - onSubmit: (def: SSHFormSessionDef) => void -}) { - sshPeers = sshPeers.slice().sort((a, b) => a.name.localeCompare(b.name)) - const [username, setUsername] = useState("") - const [hostname, setHostname] = useState(sshPeers[0].name) - return ( -
{ - e.preventDefault() - onSubmit({ username, hostname }) - }} - > - setUsername(e.currentTarget.value)} - /> -
- -
- { - if (e.altKey) { - e.preventDefault() - e.stopPropagation() - onSubmit({ username, hostname, newWindow: true }) - } - }} - /> -
- ) -} - -const NewWindow = ({ - children, - close, -}: { - children: VNode - close: () => void -}) => { - const newWindow = useMemo(() => { - const newWindow = window.open(undefined, undefined, "width=600,height=400") - if (newWindow) { - const containerNode = newWindow.document.createElement("div") - containerNode.className = "h-screen flex flex-col overflow-hidden" - newWindow.document.body.appendChild(containerNode) - - for (const linkNode of document.querySelectorAll( - "head link[rel=stylesheet]" - )) { - const newLink = document.createElement("link") - newLink.rel = "stylesheet" - newLink.href = (linkNode as HTMLLinkElement).href - newWindow.document.head.appendChild(newLink) - } - } - return newWindow - }, []) - if (!newWindow) { - console.error("Could not open window") - return null - } - newWindow.onbeforeunload = () => { - close() - } - - useEffect(() => () => newWindow.close(), []) - return createPortal(children, newWindow.document.body.lastChild as Element) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +import { useState, useCallback, useMemo, useEffect, useRef } from "preact/hooks" +import { createPortal } from "preact/compat" +import type { VNode } from "preact" +import { runSSHSession, SSHSessionDef } from "../lib/ssh" + +export function SSH({ netMap, ipn }: { netMap: IPNNetMap; ipn: IPN }) { + const [sshSessionDef, setSSHSessionDef] = useState( + null + ) + const clearSSHSessionDef = useCallback(() => setSSHSessionDef(null), []) + if (sshSessionDef) { + const sshSession = ( + + ) + if (sshSessionDef.newWindow) { + return {sshSession} + } + return sshSession + } + const sshPeers = netMap.peers.filter( + (p) => p.tailscaleSSHEnabled && p.online !== false + ) + + if (sshPeers.length == 0) { + return + } + + return +} + +type SSHFormSessionDef = SSHSessionDef & { newWindow?: boolean } + +function SSHSession({ + def, + ipn, + onDone, +}: { + def: SSHSessionDef + ipn: IPN + onDone: () => void +}) { + const ref = useRef(null) + useEffect(() => { + if (ref.current) { + runSSHSession(ref.current, def, ipn, { + onConnectionProgress: (p) => console.log("Connection progress", p), + onConnected() {}, + onError: (err) => console.error(err), + onDone, + }) + } + }, [ref]) + + return
+} + +function NoSSHPeers() { + return ( +
+ None of your machines have{" "} + + Tailscale SSH + + {" "}enabled. Give it a try! +
+ ) +} + +function SSHForm({ + sshPeers, + onSubmit, +}: { + sshPeers: IPNNetMapPeerNode[] + onSubmit: (def: SSHFormSessionDef) => void +}) { + sshPeers = sshPeers.slice().sort((a, b) => a.name.localeCompare(b.name)) + const [username, setUsername] = useState("") + const [hostname, setHostname] = useState(sshPeers[0].name) + return ( +
{ + e.preventDefault() + onSubmit({ username, hostname }) + }} + > + setUsername(e.currentTarget.value)} + /> +
+ +
+ { + if (e.altKey) { + e.preventDefault() + e.stopPropagation() + onSubmit({ username, hostname, newWindow: true }) + } + }} + /> +
+ ) +} + +const NewWindow = ({ + children, + close, +}: { + children: VNode + close: () => void +}) => { + const newWindow = useMemo(() => { + const newWindow = window.open(undefined, undefined, "width=600,height=400") + if (newWindow) { + const containerNode = newWindow.document.createElement("div") + containerNode.className = "h-screen flex flex-col overflow-hidden" + newWindow.document.body.appendChild(containerNode) + + for (const linkNode of document.querySelectorAll( + "head link[rel=stylesheet]" + )) { + const newLink = document.createElement("link") + newLink.rel = "stylesheet" + newLink.href = (linkNode as HTMLLinkElement).href + newWindow.document.head.appendChild(newLink) + } + } + return newWindow + }, []) + if (!newWindow) { + console.error("Could not open window") + return null + } + newWindow.onbeforeunload = () => { + close() + } + + useEffect(() => () => newWindow.close(), []) + return createPortal(children, newWindow.document.body.lastChild as Element) +} diff --git a/cmd/tsconnect/src/app/url-display.tsx b/cmd/tsconnect/src/app/url-display.tsx index c9b590181..fc82c7fb9 100644 --- a/cmd/tsconnect/src/app/url-display.tsx +++ b/cmd/tsconnect/src/app/url-display.tsx @@ -1,31 +1,31 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -import { useState } from "preact/hooks" -import * as qrcode from "qrcode" - -export function URLDisplay({ url }: { url: string }) { - const [dataURL, setDataURL] = useState("") - qrcode.toDataURL(url, { width: 512 }, (err, dataURL) => { - if (err) { - console.error("Error generating QR code", err) - } else { - setDataURL(dataURL) - } - }) - - return ( - - ) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +import { useState } from "preact/hooks" +import * as qrcode from "qrcode" + +export function URLDisplay({ url }: { url: string }) { + const [dataURL, setDataURL] = useState("") + qrcode.toDataURL(url, { width: 512 }, (err, dataURL) => { + if (err) { + console.error("Error generating QR code", err) + } else { + setDataURL(dataURL) + } + }) + + return ( + + ) +} diff --git a/cmd/tsconnect/src/lib/js-state-store.ts b/cmd/tsconnect/src/lib/js-state-store.ts index 7685e28a9..e57dfd98e 100644 --- a/cmd/tsconnect/src/lib/js-state-store.ts +++ b/cmd/tsconnect/src/lib/js-state-store.ts @@ -1,13 +1,13 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -/** @fileoverview Callbacks used by jsStateStore to persist IPN state. */ - -export const sessionStateStorage: IPNStateStorage = { - setState(id, value) { - window.sessionStorage[`ipn-state-${id}`] = value - }, - getState(id) { - return window.sessionStorage[`ipn-state-${id}`] || "" - }, -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +/** @fileoverview Callbacks used by jsStateStore to persist IPN state. */ + +export const sessionStateStorage: IPNStateStorage = { + setState(id, value) { + window.sessionStorage[`ipn-state-${id}`] = value + }, + getState(id) { + return window.sessionStorage[`ipn-state-${id}`] || "" + }, +} diff --git a/cmd/tsconnect/src/pkg/pkg.css b/cmd/tsconnect/src/pkg/pkg.css index 60146d5b7..76ea21f5b 100644 --- a/cmd/tsconnect/src/pkg/pkg.css +++ b/cmd/tsconnect/src/pkg/pkg.css @@ -1,8 +1,8 @@ -/* Copyright (c) Tailscale Inc & AUTHORS */ -/* SPDX-License-Identifier: BSD-3-Clause */ - -@import "xterm/css/xterm.css"; - -@tailwind base; -@tailwind components; -@tailwind utilities; +/* Copyright (c) Tailscale Inc & AUTHORS */ +/* SPDX-License-Identifier: BSD-3-Clause */ + +@import "xterm/css/xterm.css"; + +@tailwind base; +@tailwind components; +@tailwind utilities; diff --git a/cmd/tsconnect/src/pkg/pkg.ts b/cmd/tsconnect/src/pkg/pkg.ts index c0dcb5652..4d535cb40 100644 --- a/cmd/tsconnect/src/pkg/pkg.ts +++ b/cmd/tsconnect/src/pkg/pkg.ts @@ -1,40 +1,40 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Type definitions need to be manually imported for dts-bundle-generator to -// discover them. -/// -/// - -import "../wasm_exec" -import wasmURL from "./main.wasm" - -/** - * Superset of the IPNConfig type, with additional configuration that is - * needed for the package to function. - */ -type IPNPackageConfig = IPNConfig & { - // Auth key used to initialize the Tailscale client (required) - authKey: string - // URL of the main.wasm file that is included in the page, if it is not - // accessible via a relative URL. - wasmURL?: string - // Function invoked if the Go process panics or unexpectedly exits. - panicHandler: (err: string) => void -} - -export async function createIPN(config: IPNPackageConfig): Promise { - const go = new Go() - const wasmInstance = await WebAssembly.instantiateStreaming( - fetch(config.wasmURL ?? wasmURL), - go.importObject - ) - // The Go process should never exit, if it does then it's an unhandled panic. - go.run(wasmInstance.instance).then(() => - config.panicHandler("Unexpected shutdown") - ) - - return newIPN(config) -} - -export { runSSHSession } from "../lib/ssh" +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Type definitions need to be manually imported for dts-bundle-generator to +// discover them. +/// +/// + +import "../wasm_exec" +import wasmURL from "./main.wasm" + +/** + * Superset of the IPNConfig type, with additional configuration that is + * needed for the package to function. + */ +type IPNPackageConfig = IPNConfig & { + // Auth key used to initialize the Tailscale client (required) + authKey: string + // URL of the main.wasm file that is included in the page, if it is not + // accessible via a relative URL. + wasmURL?: string + // Function invoked if the Go process panics or unexpectedly exits. + panicHandler: (err: string) => void +} + +export async function createIPN(config: IPNPackageConfig): Promise { + const go = new Go() + const wasmInstance = await WebAssembly.instantiateStreaming( + fetch(config.wasmURL ?? wasmURL), + go.importObject + ) + // The Go process should never exit, if it does then it's an unhandled panic. + go.run(wasmInstance.instance).then(() => + config.panicHandler("Unexpected shutdown") + ) + + return newIPN(config) +} + +export { runSSHSession } from "../lib/ssh" diff --git a/cmd/tsconnect/src/types/esbuild.d.ts b/cmd/tsconnect/src/types/esbuild.d.ts index 7153b4244..ef28f7b1c 100644 --- a/cmd/tsconnect/src/types/esbuild.d.ts +++ b/cmd/tsconnect/src/types/esbuild.d.ts @@ -1,14 +1,14 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -/** - * @fileoverview Type definitions for types generated by the esbuild build - * process. - */ - -declare module "*.wasm" { - const path: string - export default path -} - -declare const DEBUG: boolean +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +/** + * @fileoverview Type definitions for types generated by the esbuild build + * process. + */ + +declare module "*.wasm" { + const path: string + export default path +} + +declare const DEBUG: boolean diff --git a/cmd/tsconnect/src/types/wasm_js.d.ts b/cmd/tsconnect/src/types/wasm_js.d.ts index 82822c508..492197ccb 100644 --- a/cmd/tsconnect/src/types/wasm_js.d.ts +++ b/cmd/tsconnect/src/types/wasm_js.d.ts @@ -1,103 +1,103 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -/** - * @fileoverview Type definitions for types exported by the wasm_js.go Go - * module. - */ - -declare global { - function newIPN(config: IPNConfig): IPN - - interface IPN { - run(callbacks: IPNCallbacks): void - login(): void - logout(): void - ssh( - host: string, - username: string, - termConfig: { - writeFn: (data: string) => void - writeErrorFn: (err: string) => void - setReadFn: (readFn: (data: string) => void) => void - rows: number - cols: number - /** Defaults to 5 seconds */ - timeoutSeconds?: number - onConnectionProgress: (message: string) => void - onConnected: () => void - onDone: () => void - } - ): IPNSSHSession - fetch(url: string): Promise<{ - status: number - statusText: string - text: () => Promise - }> - } - - interface IPNSSHSession { - resize(rows: number, cols: number): boolean - close(): boolean - } - - interface IPNStateStorage { - setState(id: string, value: string): void - getState(id: string): string - } - - type IPNConfig = { - stateStorage?: IPNStateStorage - authKey?: string - controlURL?: string - hostname?: string - } - - type IPNCallbacks = { - notifyState: (state: IPNState) => void - notifyNetMap: (netMapStr: string) => void - notifyBrowseToURL: (url: string) => void - notifyPanicRecover: (err: string) => void - } - - type IPNNetMap = { - self: IPNNetMapSelfNode - peers: IPNNetMapPeerNode[] - lockedOut: boolean - } - - type IPNNetMapNode = { - name: string - addresses: string[] - machineKey: string - nodeKey: string - } - - type IPNNetMapSelfNode = IPNNetMapNode & { - machineStatus: IPNMachineStatus - } - - type IPNNetMapPeerNode = IPNNetMapNode & { - online?: boolean - tailscaleSSHEnabled: boolean - } - - /** Mirrors values from ipn/backend.go */ - type IPNState = - | "NoState" - | "InUseOtherUser" - | "NeedsLogin" - | "NeedsMachineAuth" - | "Stopped" - | "Starting" - | "Running" - - /** Mirrors values from MachineStatus in tailcfg.go */ - type IPNMachineStatus = - | "MachineUnknown" - | "MachineUnauthorized" - | "MachineAuthorized" - | "MachineInvalid" -} - -export {} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +/** + * @fileoverview Type definitions for types exported by the wasm_js.go Go + * module. + */ + +declare global { + function newIPN(config: IPNConfig): IPN + + interface IPN { + run(callbacks: IPNCallbacks): void + login(): void + logout(): void + ssh( + host: string, + username: string, + termConfig: { + writeFn: (data: string) => void + writeErrorFn: (err: string) => void + setReadFn: (readFn: (data: string) => void) => void + rows: number + cols: number + /** Defaults to 5 seconds */ + timeoutSeconds?: number + onConnectionProgress: (message: string) => void + onConnected: () => void + onDone: () => void + } + ): IPNSSHSession + fetch(url: string): Promise<{ + status: number + statusText: string + text: () => Promise + }> + } + + interface IPNSSHSession { + resize(rows: number, cols: number): boolean + close(): boolean + } + + interface IPNStateStorage { + setState(id: string, value: string): void + getState(id: string): string + } + + type IPNConfig = { + stateStorage?: IPNStateStorage + authKey?: string + controlURL?: string + hostname?: string + } + + type IPNCallbacks = { + notifyState: (state: IPNState) => void + notifyNetMap: (netMapStr: string) => void + notifyBrowseToURL: (url: string) => void + notifyPanicRecover: (err: string) => void + } + + type IPNNetMap = { + self: IPNNetMapSelfNode + peers: IPNNetMapPeerNode[] + lockedOut: boolean + } + + type IPNNetMapNode = { + name: string + addresses: string[] + machineKey: string + nodeKey: string + } + + type IPNNetMapSelfNode = IPNNetMapNode & { + machineStatus: IPNMachineStatus + } + + type IPNNetMapPeerNode = IPNNetMapNode & { + online?: boolean + tailscaleSSHEnabled: boolean + } + + /** Mirrors values from ipn/backend.go */ + type IPNState = + | "NoState" + | "InUseOtherUser" + | "NeedsLogin" + | "NeedsMachineAuth" + | "Stopped" + | "Starting" + | "Running" + + /** Mirrors values from MachineStatus in tailcfg.go */ + type IPNMachineStatus = + | "MachineUnknown" + | "MachineUnauthorized" + | "MachineAuthorized" + | "MachineInvalid" +} + +export {} diff --git a/cmd/tsconnect/tailwind.config.js b/cmd/tsconnect/tailwind.config.js index 38bc5b97b..31823000b 100644 --- a/cmd/tsconnect/tailwind.config.js +++ b/cmd/tsconnect/tailwind.config.js @@ -1,8 +1,8 @@ -/** @type {import('tailwindcss').Config} */ -module.exports = { - content: ["./index.html", "./src/**/*.ts", "./src/**/*.tsx"], - theme: { - extend: {}, - }, - plugins: [], -} +/** @type {import('tailwindcss').Config} */ +module.exports = { + content: ["./index.html", "./src/**/*.ts", "./src/**/*.tsx"], + theme: { + extend: {}, + }, + plugins: [], +} diff --git a/cmd/tsconnect/tsconfig.json b/cmd/tsconnect/tsconfig.json index 1148e2ef0..52c25c727 100644 --- a/cmd/tsconnect/tsconfig.json +++ b/cmd/tsconnect/tsconfig.json @@ -1,15 +1,15 @@ -{ - "compilerOptions": { - "target": "ES2017", - "module": "ES2020", - "moduleResolution": "node", - "isolatedModules": true, - "strict": true, - "forceConsistentCasingInFileNames": true, - "sourceMap": true, - "jsx": "react-jsx", - "jsxImportSource": "preact" - }, - "include": ["src/**/*"], - "exclude": ["node_modules"] -} +{ + "compilerOptions": { + "target": "ES2017", + "module": "ES2020", + "moduleResolution": "node", + "isolatedModules": true, + "strict": true, + "forceConsistentCasingInFileNames": true, + "sourceMap": true, + "jsx": "react-jsx", + "jsxImportSource": "preact" + }, + "include": ["src/**/*"], + "exclude": ["node_modules"] +} diff --git a/cmd/tsconnect/tsconnect.go b/cmd/tsconnect/tsconnect.go index 60ea6ef82..4c8a0a52e 100644 --- a/cmd/tsconnect/tsconnect.go +++ b/cmd/tsconnect/tsconnect.go @@ -1,71 +1,71 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !plan9 - -// The tsconnect command builds and serves the static site that is generated for -// the Tailscale Connect JS/WASM client. Can be run in 3 modes: -// - dev: builds the site and serves it. JS and CSS changes can be picked up -// with a reload. -// - build: builds the site and writes it to dist/ -// - serve: serves the site from dist/ (embedded in the binary) -package main // import "tailscale.com/cmd/tsconnect" - -import ( - "flag" - "fmt" - "log" - "os" -) - -var ( - addr = flag.String("addr", ":9090", "address to listen on") - distDir = flag.String("distdir", "./dist", "path of directory to place build output in") - pkgDir = flag.String("pkgdir", "./pkg", "path of directory to place NPM package build output in") - yarnPath = flag.String("yarnpath", "", "path yarn executable used to install JavaScript dependencies") - fastCompression = flag.Bool("fast-compression", false, "Use faster compression when building, to speed up build time. Meant to iterative/debugging use only.") - devControl = flag.String("dev-control", "", "URL of a development control server to be used with dev. If provided without specifying dev, an error will be returned.") - rootDir = flag.String("rootdir", "", "Root directory of repo. If not specified, will be inferred from the cwd.") -) - -func main() { - flag.Usage = usage - flag.Parse() - if len(flag.Args()) != 1 { - flag.Usage() - } - - switch flag.Arg(0) { - case "dev": - runDev() - case "dev-pkg": - runDevPkg() - case "build": - runBuild() - case "build-pkg": - runBuildPkg() - case "serve": - runServe() - default: - log.Printf("Unknown command: %s", flag.Arg(0)) - flag.Usage() - } -} - -func usage() { - fmt.Fprintf(os.Stderr, ` -usage: tsconnect {dev|build|serve} -`[1:]) - - flag.PrintDefaults() - fmt.Fprintf(os.Stderr, ` - -tsconnect implements development/build/serving workflows for Tailscale Connect. -It can be invoked with one of three subcommands: - -- dev: Run in development mode, allowing JS and CSS changes to be picked up without a rebuilt or restart. -- build: Run in production build mode (generating static assets) -- serve: Run in production serve mode (serving static assets) -`[1:]) - os.Exit(2) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +// The tsconnect command builds and serves the static site that is generated for +// the Tailscale Connect JS/WASM client. Can be run in 3 modes: +// - dev: builds the site and serves it. JS and CSS changes can be picked up +// with a reload. +// - build: builds the site and writes it to dist/ +// - serve: serves the site from dist/ (embedded in the binary) +package main // import "tailscale.com/cmd/tsconnect" + +import ( + "flag" + "fmt" + "log" + "os" +) + +var ( + addr = flag.String("addr", ":9090", "address to listen on") + distDir = flag.String("distdir", "./dist", "path of directory to place build output in") + pkgDir = flag.String("pkgdir", "./pkg", "path of directory to place NPM package build output in") + yarnPath = flag.String("yarnpath", "", "path yarn executable used to install JavaScript dependencies") + fastCompression = flag.Bool("fast-compression", false, "Use faster compression when building, to speed up build time. Meant to iterative/debugging use only.") + devControl = flag.String("dev-control", "", "URL of a development control server to be used with dev. If provided without specifying dev, an error will be returned.") + rootDir = flag.String("rootdir", "", "Root directory of repo. If not specified, will be inferred from the cwd.") +) + +func main() { + flag.Usage = usage + flag.Parse() + if len(flag.Args()) != 1 { + flag.Usage() + } + + switch flag.Arg(0) { + case "dev": + runDev() + case "dev-pkg": + runDevPkg() + case "build": + runBuild() + case "build-pkg": + runBuildPkg() + case "serve": + runServe() + default: + log.Printf("Unknown command: %s", flag.Arg(0)) + flag.Usage() + } +} + +func usage() { + fmt.Fprintf(os.Stderr, ` +usage: tsconnect {dev|build|serve} +`[1:]) + + flag.PrintDefaults() + fmt.Fprintf(os.Stderr, ` + +tsconnect implements development/build/serving workflows for Tailscale Connect. +It can be invoked with one of three subcommands: + +- dev: Run in development mode, allowing JS and CSS changes to be picked up without a rebuilt or restart. +- build: Run in production build mode (generating static assets) +- serve: Run in production serve mode (serving static assets) +`[1:]) + os.Exit(2) +} diff --git a/cmd/tsconnect/yarn.lock b/cmd/tsconnect/yarn.lock index 914b4e6d0..663a1244e 100644 --- a/cmd/tsconnect/yarn.lock +++ b/cmd/tsconnect/yarn.lock @@ -1,713 +1,713 @@ -# THIS IS AN AUTOGENERATED FILE. DO NOT EDIT THIS FILE DIRECTLY. -# yarn lockfile v1 - - -"@nodelib/fs.scandir@2.1.5": - version "2.1.5" - resolved "https://registry.yarnpkg.com/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz#7619c2eb21b25483f6d167548b4cfd5a7488c3d5" - integrity sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g== - dependencies: - "@nodelib/fs.stat" "2.0.5" - run-parallel "^1.1.9" - -"@nodelib/fs.stat@2.0.5", "@nodelib/fs.stat@^2.0.2": - version "2.0.5" - resolved "https://registry.yarnpkg.com/@nodelib/fs.stat/-/fs.stat-2.0.5.tgz#5bd262af94e9d25bd1e71b05deed44876a222e8b" - integrity sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A== - -"@nodelib/fs.walk@^1.2.3": - version "1.2.8" - resolved "https://registry.yarnpkg.com/@nodelib/fs.walk/-/fs.walk-1.2.8.tgz#e95737e8bb6746ddedf69c556953494f196fe69a" - integrity sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg== - dependencies: - "@nodelib/fs.scandir" "2.1.5" - fastq "^1.6.0" - -"@types/golang-wasm-exec@^1.15.0": - version "1.15.0" - resolved "https://registry.yarnpkg.com/@types/golang-wasm-exec/-/golang-wasm-exec-1.15.0.tgz#d0aafbb2b0dc07eaf45dfb83bfb6cdd5b2b3c55c" - integrity sha512-FrL97mp7WW8LqNinVkzTVKOIQKuYjQqgucnh41+1vRQ+bf1LT8uh++KRf9otZPXsa6H1p8ruIGz1BmCGttOL6Q== - -"@types/node@*": - version "18.6.1" - resolved "https://registry.yarnpkg.com/@types/node/-/node-18.6.1.tgz#828e4785ccca13f44e2fb6852ae0ef11e3e20ba5" - integrity sha512-z+2vB6yDt1fNwKOeGbckpmirO+VBDuQqecXkgeIqDlaOtmKn6hPR/viQ8cxCfqLU4fTlvM3+YjM367TukWdxpg== - -"@types/qrcode@^1.4.2": - version "1.4.2" - resolved "https://registry.yarnpkg.com/@types/qrcode/-/qrcode-1.4.2.tgz#7d7142d6fa9921f195db342ed08b539181546c74" - integrity sha512-7uNT9L4WQTNJejHTSTdaJhfBSCN73xtXaHFyBJ8TSwiLhe4PRuTue7Iph0s2nG9R/ifUaSnGhLUOZavlBEqDWQ== - dependencies: - "@types/node" "*" - -acorn-node@^1.8.2: - version "1.8.2" - resolved "https://registry.yarnpkg.com/acorn-node/-/acorn-node-1.8.2.tgz#114c95d64539e53dede23de8b9d96df7c7ae2af8" - integrity sha512-8mt+fslDufLYntIoPAaIMUe/lrbrehIiwmR3t2k9LljIzoigEPF27eLk2hy8zSGzmR/ogr7zbRKINMo1u0yh5A== - dependencies: - acorn "^7.0.0" - acorn-walk "^7.0.0" - xtend "^4.0.2" - -acorn-walk@^7.0.0: - version "7.2.0" - resolved "https://registry.yarnpkg.com/acorn-walk/-/acorn-walk-7.2.0.tgz#0de889a601203909b0fbe07b8938dc21d2e967bc" - integrity sha512-OPdCF6GsMIP+Az+aWfAAOEt2/+iVDKE7oy6lJ098aoe59oAmK76qV6Gw60SbZ8jHuG2wH058GF4pLFbYamYrVA== - -acorn@^7.0.0: - version "7.4.1" - resolved "https://registry.yarnpkg.com/acorn/-/acorn-7.4.1.tgz#feaed255973d2e77555b83dbc08851a6c63520fa" - integrity sha512-nQyp0o1/mNdbTO1PO6kHkwSrmgZ0MT/jCCpNiwbUjGoRN4dlBhqJtoQuCnEOKzgTVwg0ZWiCoQy6SxMebQVh8A== - -ansi-regex@^5.0.1: - version "5.0.1" - resolved "https://registry.yarnpkg.com/ansi-regex/-/ansi-regex-5.0.1.tgz#082cb2c89c9fe8659a311a53bd6a4dc5301db304" - integrity sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ== - -ansi-styles@^4.0.0: - version "4.3.0" - resolved "https://registry.yarnpkg.com/ansi-styles/-/ansi-styles-4.3.0.tgz#edd803628ae71c04c85ae7a0906edad34b648937" - integrity sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg== - dependencies: - color-convert "^2.0.1" - -anymatch@~3.1.2: - version "3.1.2" - resolved "https://registry.yarnpkg.com/anymatch/-/anymatch-3.1.2.tgz#c0557c096af32f106198f4f4e2a383537e378716" - integrity sha512-P43ePfOAIupkguHUycrc4qJ9kz8ZiuOUijaETwX7THt0Y/GNK7v0aa8rY816xWjZ7rJdA5XdMcpVFTKMq+RvWg== - dependencies: - normalize-path "^3.0.0" - picomatch "^2.0.4" - -arg@^5.0.2: - version "5.0.2" - resolved "https://registry.yarnpkg.com/arg/-/arg-5.0.2.tgz#c81433cc427c92c4dcf4865142dbca6f15acd59c" - integrity sha512-PYjyFOLKQ9y57JvQ6QLo8dAgNqswh8M1RMJYdQduT6xbWSgK36P/Z/v+p888pM69jMMfS8Xd8F6I1kQ/I9HUGg== - -binary-extensions@^2.0.0: - version "2.2.0" - resolved "https://registry.yarnpkg.com/binary-extensions/-/binary-extensions-2.2.0.tgz#75f502eeaf9ffde42fc98829645be4ea76bd9e2d" - integrity sha512-jDctJ/IVQbZoJykoeHbhXpOlNBqGNcwXJKJog42E5HDPUwQTSdjCHdihjj0DlnheQ7blbT6dHOafNAiS8ooQKA== - -braces@^3.0.2, braces@~3.0.2: - version "3.0.2" - resolved "https://registry.yarnpkg.com/braces/-/braces-3.0.2.tgz#3454e1a462ee8d599e236df336cd9ea4f8afe107" - integrity sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A== - dependencies: - fill-range "^7.0.1" - -camelcase-css@^2.0.1: - version "2.0.1" - resolved "https://registry.yarnpkg.com/camelcase-css/-/camelcase-css-2.0.1.tgz#ee978f6947914cc30c6b44741b6ed1df7f043fd5" - integrity sha512-QOSvevhslijgYwRx6Rv7zKdMF8lbRmx+uQGx2+vDc+KI/eBnsy9kit5aj23AgGu3pa4t9AgwbnXWqS+iOY+2aA== - -camelcase@^5.0.0: - version "5.3.1" - resolved "https://registry.yarnpkg.com/camelcase/-/camelcase-5.3.1.tgz#e3c9b31569e106811df242f715725a1f4c494320" - integrity sha512-L28STB170nwWS63UjtlEOE3dldQApaJXZkOI1uMFfzf3rRuPegHaHesyee+YxQ+W6SvRDQV6UrdOdRiR153wJg== - -chokidar@^3.5.3: - version "3.5.3" - resolved "https://registry.yarnpkg.com/chokidar/-/chokidar-3.5.3.tgz#1cf37c8707b932bd1af1ae22c0432e2acd1903bd" - integrity sha512-Dr3sfKRP6oTcjf2JmUmFJfeVMvXBdegxB0iVQ5eb2V10uFJUCAS8OByZdVAyVb8xXNz3GjjTgj9kLWsZTqE6kw== - dependencies: - anymatch "~3.1.2" - braces "~3.0.2" - glob-parent "~5.1.2" - is-binary-path "~2.1.0" - is-glob "~4.0.1" - normalize-path "~3.0.0" - readdirp "~3.6.0" - optionalDependencies: - fsevents "~2.3.2" - -cliui@^6.0.0: - version "6.0.0" - resolved "https://registry.yarnpkg.com/cliui/-/cliui-6.0.0.tgz#511d702c0c4e41ca156d7d0e96021f23e13225b1" - integrity sha512-t6wbgtoCXvAzst7QgXxJYqPt0usEfbgQdftEPbLL/cvv6HPE5VgvqCuAIDR0NgU52ds6rFwqrgakNLrHEjCbrQ== - dependencies: - string-width "^4.2.0" - strip-ansi "^6.0.0" - wrap-ansi "^6.2.0" - -cliui@^7.0.2: - version "7.0.4" - resolved "https://registry.yarnpkg.com/cliui/-/cliui-7.0.4.tgz#a0265ee655476fc807aea9df3df8df7783808b4f" - integrity sha512-OcRE68cOsVMXp1Yvonl/fzkQOyjLSu/8bhPDfQt0e0/Eb283TKP20Fs2MqoPsr9SwA595rRCA+QMzYc9nBP+JQ== - dependencies: - string-width "^4.2.0" - strip-ansi "^6.0.0" - wrap-ansi "^7.0.0" - -color-convert@^2.0.1: - version "2.0.1" - resolved "https://registry.yarnpkg.com/color-convert/-/color-convert-2.0.1.tgz#72d3a68d598c9bdb3af2ad1e84f21d896abd4de3" - integrity sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ== - dependencies: - color-name "~1.1.4" - -color-name@^1.1.4, color-name@~1.1.4: - version "1.1.4" - resolved "https://registry.yarnpkg.com/color-name/-/color-name-1.1.4.tgz#c2a09a87acbde69543de6f63fa3995c826c536a2" - integrity sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA== - -cssesc@^3.0.0: - version "3.0.0" - resolved "https://registry.yarnpkg.com/cssesc/-/cssesc-3.0.0.tgz#37741919903b868565e1c09ea747445cd18983ee" - integrity sha512-/Tb/JcjK111nNScGob5MNtsntNM1aCNUDipB/TkwZFhyDrrE47SOx/18wF2bbjgc3ZzCSKW1T5nt5EbFoAz/Vg== - -decamelize@^1.2.0: - version "1.2.0" - resolved "https://registry.yarnpkg.com/decamelize/-/decamelize-1.2.0.tgz#f6534d15148269b20352e7bee26f501f9a191290" - integrity sha1-9lNNFRSCabIDUue+4m9QH5oZEpA= - -defined@^1.0.0: - version "1.0.0" - resolved "https://registry.yarnpkg.com/defined/-/defined-1.0.0.tgz#c98d9bcef75674188e110969151199e39b1fa693" - integrity sha512-Y2caI5+ZwS5c3RiNDJ6u53VhQHv+hHKwhkI1iHvceKUHw9Df6EK2zRLfjejRgMuCuxK7PfSWIMwWecceVvThjQ== - -detective@^5.2.1: - version "5.2.1" - resolved "https://registry.yarnpkg.com/detective/-/detective-5.2.1.tgz#6af01eeda11015acb0e73f933242b70f24f91034" - integrity sha512-v9XE1zRnz1wRtgurGu0Bs8uHKFSTdteYZNbIPFVhUZ39L/S79ppMpdmVOZAnoz1jfEFodc48n6MX483Xo3t1yw== - dependencies: - acorn-node "^1.8.2" - defined "^1.0.0" - minimist "^1.2.6" - -didyoumean@^1.2.2: - version "1.2.2" - resolved "https://registry.yarnpkg.com/didyoumean/-/didyoumean-1.2.2.tgz#989346ffe9e839b4555ecf5666edea0d3e8ad037" - integrity sha512-gxtyfqMg7GKyhQmb056K7M3xszy/myH8w+B4RT+QXBQsvAOdc3XymqDDPHx1BgPgsdAA5SIifona89YtRATDzw== - -dijkstrajs@^1.0.1: - version "1.0.2" - resolved "https://registry.yarnpkg.com/dijkstrajs/-/dijkstrajs-1.0.2.tgz#2e48c0d3b825462afe75ab4ad5e829c8ece36257" - integrity sha512-QV6PMaHTCNmKSeP6QoXhVTw9snc9VD8MulTT0Bd99Pacp4SS1cjcrYPgBPmibqKVtMJJfqC6XvOXgPMEEPH/fg== - -dlv@^1.1.3: - version "1.1.3" - resolved "https://registry.yarnpkg.com/dlv/-/dlv-1.1.3.tgz#5c198a8a11453596e751494d49874bc7732f2e79" - integrity sha512-+HlytyjlPKnIG8XuRG8WvmBP8xs8P71y+SKKS6ZXWoEgLuePxtDoUEiH7WkdePWrQ5JBpE6aoVqfZfJUQkjXwA== - -dts-bundle-generator@^6.12.0: - version "6.12.0" - resolved "https://registry.yarnpkg.com/dts-bundle-generator/-/dts-bundle-generator-6.12.0.tgz#0a221bdce5fdd309a56c8556e645f16ed87ab07d" - integrity sha512-k/QAvuVaLIdyWRUHduDrWBe4j8PcE6TDt06+f32KHbW7/SmUPbX1O23fFtQgKwUyTBkbIjJFOFtNrF97tJcKug== - dependencies: - typescript ">=3.0.1" - yargs "^17.2.1" - -emoji-regex@^8.0.0: - version "8.0.0" - resolved "https://registry.yarnpkg.com/emoji-regex/-/emoji-regex-8.0.0.tgz#e818fd69ce5ccfcb404594f842963bf53164cc37" - integrity sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A== - -encode-utf8@^1.0.3: - version "1.0.3" - resolved "https://registry.yarnpkg.com/encode-utf8/-/encode-utf8-1.0.3.tgz#f30fdd31da07fb596f281beb2f6b027851994cda" - integrity sha512-ucAnuBEhUK4boH2HjVYG5Q2mQyPorvv0u/ocS+zhdw0S8AlHYY+GOFhP1Gio5z4icpP2ivFSvhtFjQi8+T9ppw== - -escalade@^3.1.1: - version "3.1.1" - resolved "https://registry.yarnpkg.com/escalade/-/escalade-3.1.1.tgz#d8cfdc7000965c5a0174b4a82eaa5c0552742e40" - integrity sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw== - -fast-glob@^3.2.11: - version "3.2.11" - resolved "https://registry.yarnpkg.com/fast-glob/-/fast-glob-3.2.11.tgz#a1172ad95ceb8a16e20caa5c5e56480e5129c1d9" - integrity sha512-xrO3+1bxSo3ZVHAnqzyuewYT6aMFHRAd4Kcs92MAonjwQZLsK9d0SF1IyQ3k5PoirxTW0Oe/RqFgMQ6TcNE5Ew== - dependencies: - "@nodelib/fs.stat" "^2.0.2" - "@nodelib/fs.walk" "^1.2.3" - glob-parent "^5.1.2" - merge2 "^1.3.0" - micromatch "^4.0.4" - -fastq@^1.6.0: - version "1.13.0" - resolved "https://registry.yarnpkg.com/fastq/-/fastq-1.13.0.tgz#616760f88a7526bdfc596b7cab8c18938c36b98c" - integrity sha512-YpkpUnK8od0o1hmeSc7UUs/eB/vIPWJYjKck2QKIzAf71Vm1AAQ3EbuZB3g2JIy+pg+ERD0vqI79KyZiB2e2Nw== - dependencies: - reusify "^1.0.4" - -fill-range@^7.0.1: - version "7.0.1" - resolved "https://registry.yarnpkg.com/fill-range/-/fill-range-7.0.1.tgz#1919a6a7c75fe38b2c7c77e5198535da9acdda40" - integrity sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ== - dependencies: - to-regex-range "^5.0.1" - -find-up@^4.1.0: - version "4.1.0" - resolved "https://registry.yarnpkg.com/find-up/-/find-up-4.1.0.tgz#97afe7d6cdc0bc5928584b7c8d7b16e8a9aa5d19" - integrity sha512-PpOwAdQ/YlXQ2vj8a3h8IipDuYRi3wceVQQGYWxNINccq40Anw7BlsEXCMbt1Zt+OLA6Fq9suIpIWD0OsnISlw== - dependencies: - locate-path "^5.0.0" - path-exists "^4.0.0" - -fsevents@~2.3.2: - version "2.3.2" - resolved "https://registry.yarnpkg.com/fsevents/-/fsevents-2.3.2.tgz#8a526f78b8fdf4623b709e0b975c52c24c02fd1a" - integrity sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA== - -function-bind@^1.1.1: - version "1.1.1" - resolved "https://registry.yarnpkg.com/function-bind/-/function-bind-1.1.1.tgz#a56899d3ea3c9bab874bb9773b7c5ede92f4895d" - integrity sha512-yIovAzMX49sF8Yl58fSCWJ5svSLuaibPxXQJFLmBObTuCr0Mf1KiPopGM9NiFjiYBCbfaa2Fh6breQ6ANVTI0A== - -get-caller-file@^2.0.1, get-caller-file@^2.0.5: - version "2.0.5" - resolved "https://registry.yarnpkg.com/get-caller-file/-/get-caller-file-2.0.5.tgz#4f94412a82db32f36e3b0b9741f8a97feb031f7e" - integrity sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg== - -glob-parent@^5.1.2, glob-parent@~5.1.2: - version "5.1.2" - resolved "https://registry.yarnpkg.com/glob-parent/-/glob-parent-5.1.2.tgz#869832c58034fe68a4093c17dc15e8340d8401c4" - integrity sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow== - dependencies: - is-glob "^4.0.1" - -glob-parent@^6.0.2: - version "6.0.2" - resolved "https://registry.yarnpkg.com/glob-parent/-/glob-parent-6.0.2.tgz#6d237d99083950c79290f24c7642a3de9a28f9e3" - integrity sha512-XxwI8EOhVQgWp6iDL+3b0r86f4d6AX6zSU55HfB4ydCEuXLXc5FcYeOu+nnGftS4TEju/11rt4KJPTMgbfmv4A== - dependencies: - is-glob "^4.0.3" - -has@^1.0.3: - version "1.0.3" - resolved "https://registry.yarnpkg.com/has/-/has-1.0.3.tgz#722d7cbfc1f6aa8241f16dd814e011e1f41e8796" - integrity sha512-f2dvO0VU6Oej7RkWJGrehjbzMAjFp5/VKPp5tTpWIV4JHHZK1/BxbFRtf/siA2SWTe09caDmVtYYzWEIbBS4zw== - dependencies: - function-bind "^1.1.1" - -is-binary-path@~2.1.0: - version "2.1.0" - resolved "https://registry.yarnpkg.com/is-binary-path/-/is-binary-path-2.1.0.tgz#ea1f7f3b80f064236e83470f86c09c254fb45b09" - integrity sha512-ZMERYes6pDydyuGidse7OsHxtbI7WVeUEozgR/g7rd0xUimYNlvZRE/K2MgZTjWy725IfelLeVcEM97mmtRGXw== - dependencies: - binary-extensions "^2.0.0" - -is-core-module@^2.9.0: - version "2.9.0" - resolved "https://registry.yarnpkg.com/is-core-module/-/is-core-module-2.9.0.tgz#e1c34429cd51c6dd9e09e0799e396e27b19a9c69" - integrity sha512-+5FPy5PnwmO3lvfMb0AsoPaBG+5KHUI0wYFXOtYPnVVVspTFUuMZNfNaNVRt3FZadstu2c8x23vykRW/NBoU6A== - dependencies: - has "^1.0.3" - -is-extglob@^2.1.1: - version "2.1.1" - resolved "https://registry.yarnpkg.com/is-extglob/-/is-extglob-2.1.1.tgz#a88c02535791f02ed37c76a1b9ea9773c833f8c2" - integrity sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ== - -is-fullwidth-code-point@^3.0.0: - version "3.0.0" - resolved "https://registry.yarnpkg.com/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz#f116f8064fe90b3f7844a38997c0b75051269f1d" - integrity sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg== - -is-glob@^4.0.1, is-glob@^4.0.3, is-glob@~4.0.1: - version "4.0.3" - resolved "https://registry.yarnpkg.com/is-glob/-/is-glob-4.0.3.tgz#64f61e42cbbb2eec2071a9dac0b28ba1e65d5084" - integrity sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg== - dependencies: - is-extglob "^2.1.1" - -is-number@^7.0.0: - version "7.0.0" - resolved "https://registry.yarnpkg.com/is-number/-/is-number-7.0.0.tgz#7535345b896734d5f80c4d06c50955527a14f12b" - integrity sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng== - -lilconfig@^2.0.5: - version "2.0.6" - resolved "https://registry.yarnpkg.com/lilconfig/-/lilconfig-2.0.6.tgz#32a384558bd58af3d4c6e077dd1ad1d397bc69d4" - integrity sha512-9JROoBW7pobfsx+Sq2JsASvCo6Pfo6WWoUW79HuB1BCoBXD4PLWJPqDF6fNj67pqBYTbAHkE57M1kS/+L1neOg== - -locate-path@^5.0.0: - version "5.0.0" - resolved "https://registry.yarnpkg.com/locate-path/-/locate-path-5.0.0.tgz#1afba396afd676a6d42504d0a67a3a7eb9f62aa0" - integrity sha512-t7hw9pI+WvuwNJXwk5zVHpyhIqzg2qTlklJOf0mVxGSbe3Fp2VieZcduNYjaLDoy6p9uGpQEGWG87WpMKlNq8g== - dependencies: - p-locate "^4.1.0" - -merge2@^1.3.0: - version "1.4.1" - resolved "https://registry.yarnpkg.com/merge2/-/merge2-1.4.1.tgz#4368892f885e907455a6fd7dc55c0c9d404990ae" - integrity sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg== - -micromatch@^4.0.4: - version "4.0.5" - resolved "https://registry.yarnpkg.com/micromatch/-/micromatch-4.0.5.tgz#bc8999a7cbbf77cdc89f132f6e467051b49090c6" - integrity sha512-DMy+ERcEW2q8Z2Po+WNXuw3c5YaUSFjAO5GsJqfEl7UjvtIuFKO6ZrKvcItdy98dwFI2N1tg3zNIdKaQT+aNdA== - dependencies: - braces "^3.0.2" - picomatch "^2.3.1" - -minimist@^1.2.6: - version "1.2.6" - resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.6.tgz#8637a5b759ea0d6e98702cfb3a9283323c93af44" - integrity sha512-Jsjnk4bw3YJqYzbdyBiNsPWHPfO++UGG749Cxs6peCu5Xg4nrena6OVxOYxrQTqww0Jmwt+Ref8rggumkTLz9Q== - -nanoid@^3.3.4: - version "3.3.4" - resolved "https://registry.yarnpkg.com/nanoid/-/nanoid-3.3.4.tgz#730b67e3cd09e2deacf03c027c81c9d9dbc5e8ab" - integrity sha512-MqBkQh/OHTS2egovRtLk45wEyNXwF+cokD+1YPf9u5VfJiRdAiRwB2froX5Co9Rh20xs4siNPm8naNotSD6RBw== - -normalize-path@^3.0.0, normalize-path@~3.0.0: - version "3.0.0" - resolved "https://registry.yarnpkg.com/normalize-path/-/normalize-path-3.0.0.tgz#0dcd69ff23a1c9b11fd0978316644a0388216a65" - integrity sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA== - -object-hash@^3.0.0: - version "3.0.0" - resolved "https://registry.yarnpkg.com/object-hash/-/object-hash-3.0.0.tgz#73f97f753e7baffc0e2cc9d6e079079744ac82e9" - integrity sha512-RSn9F68PjH9HqtltsSnqYC1XXoWe9Bju5+213R98cNGttag9q9yAOTzdbsqvIa7aNm5WffBZFpWYr2aWrklWAw== - -p-limit@^2.2.0: - version "2.3.0" - resolved "https://registry.yarnpkg.com/p-limit/-/p-limit-2.3.0.tgz#3dd33c647a214fdfffd835933eb086da0dc21db1" - integrity sha512-//88mFWSJx8lxCzwdAABTJL2MyWB12+eIY7MDL2SqLmAkeKU9qxRvWuSyTjm3FUmpBEMuFfckAIqEaVGUDxb6w== - dependencies: - p-try "^2.0.0" - -p-locate@^4.1.0: - version "4.1.0" - resolved "https://registry.yarnpkg.com/p-locate/-/p-locate-4.1.0.tgz#a3428bb7088b3a60292f66919278b7c297ad4f07" - integrity sha512-R79ZZ/0wAxKGu3oYMlz8jy/kbhsNrS7SKZ7PxEHBgJ5+F2mtFW2fK2cOtBh1cHYkQsbzFV7I+EoRKe6Yt0oK7A== - dependencies: - p-limit "^2.2.0" - -p-try@^2.0.0: - version "2.2.0" - resolved "https://registry.yarnpkg.com/p-try/-/p-try-2.2.0.tgz#cb2868540e313d61de58fafbe35ce9004d5540e6" - integrity sha512-R4nPAVTAU0B9D35/Gk3uJf/7XYbQcyohSKdvAxIRSNghFl4e71hVoGnBNQz9cWaXxO2I10KTC+3jMdvvoKw6dQ== - -path-exists@^4.0.0: - version "4.0.0" - resolved "https://registry.yarnpkg.com/path-exists/-/path-exists-4.0.0.tgz#513bdbe2d3b95d7762e8c1137efa195c6c61b5b3" - integrity sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w== - -path-parse@^1.0.7: - version "1.0.7" - resolved "https://registry.yarnpkg.com/path-parse/-/path-parse-1.0.7.tgz#fbc114b60ca42b30d9daf5858e4bd68bbedb6735" - integrity sha512-LDJzPVEEEPR+y48z93A0Ed0yXb8pAByGWo/k5YYdYgpY2/2EsOsksJrq7lOHxryrVOn1ejG6oAp8ahvOIQD8sw== - -picocolors@^1.0.0: - version "1.0.0" - resolved "https://registry.yarnpkg.com/picocolors/-/picocolors-1.0.0.tgz#cb5bdc74ff3f51892236eaf79d68bc44564ab81c" - integrity sha512-1fygroTLlHu66zi26VoTDv8yRgm0Fccecssto+MhsZ0D/DGW2sm8E8AjW7NU5VVTRt5GxbeZ5qBuJr+HyLYkjQ== - -picomatch@^2.0.4, picomatch@^2.2.1, picomatch@^2.3.1: - version "2.3.1" - resolved "https://registry.yarnpkg.com/picomatch/-/picomatch-2.3.1.tgz#3ba3833733646d9d3e4995946c1365a67fb07a42" - integrity sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA== - -pify@^2.3.0: - version "2.3.0" - resolved "https://registry.yarnpkg.com/pify/-/pify-2.3.0.tgz#ed141a6ac043a849ea588498e7dca8b15330e90c" - integrity sha512-udgsAY+fTnvv7kI7aaxbqwWNb0AHiB0qBO89PZKPkoTmGOgdbrHDKD+0B2X4uTfJ/FT1R09r9gTsjUjNJotuog== - -pngjs@^5.0.0: - version "5.0.0" - resolved "https://registry.yarnpkg.com/pngjs/-/pngjs-5.0.0.tgz#e79dd2b215767fd9c04561c01236df960bce7fbb" - integrity sha512-40QW5YalBNfQo5yRYmiw7Yz6TKKVr3h6970B2YE+3fQpsWcrbj1PzJgxeJ19DRQjhMbKPIuMY8rFaXc8moolVw== - -postcss-import@^14.1.0: - version "14.1.0" - resolved "https://registry.yarnpkg.com/postcss-import/-/postcss-import-14.1.0.tgz#a7333ffe32f0b8795303ee9e40215dac922781f0" - integrity sha512-flwI+Vgm4SElObFVPpTIT7SU7R3qk2L7PyduMcokiaVKuWv9d/U+Gm/QAd8NDLuykTWTkcrjOeD2Pp1rMeBTGw== - dependencies: - postcss-value-parser "^4.0.0" - read-cache "^1.0.0" - resolve "^1.1.7" - -postcss-js@^4.0.0: - version "4.0.0" - resolved "https://registry.yarnpkg.com/postcss-js/-/postcss-js-4.0.0.tgz#31db79889531b80dc7bc9b0ad283e418dce0ac00" - integrity sha512-77QESFBwgX4irogGVPgQ5s07vLvFqWr228qZY+w6lW599cRlK/HmnlivnnVUxkjHnCu4J16PDMHcH+e+2HbvTQ== - dependencies: - camelcase-css "^2.0.1" - -postcss-load-config@^3.1.4: - version "3.1.4" - resolved "https://registry.yarnpkg.com/postcss-load-config/-/postcss-load-config-3.1.4.tgz#1ab2571faf84bb078877e1d07905eabe9ebda855" - integrity sha512-6DiM4E7v4coTE4uzA8U//WhtPwyhiim3eyjEMFCnUpzbrkK9wJHgKDT2mR+HbtSrd/NubVaYTOpSpjUl8NQeRg== - dependencies: - lilconfig "^2.0.5" - yaml "^1.10.2" - -postcss-nested@5.0.6: - version "5.0.6" - resolved "https://registry.yarnpkg.com/postcss-nested/-/postcss-nested-5.0.6.tgz#466343f7fc8d3d46af3e7dba3fcd47d052a945bc" - integrity sha512-rKqm2Fk0KbA8Vt3AdGN0FB9OBOMDVajMG6ZCf/GoHgdxUJ4sBFp0A/uMIRm+MJUdo33YXEtjqIz8u7DAp8B7DA== - dependencies: - postcss-selector-parser "^6.0.6" - -postcss-selector-parser@^6.0.10, postcss-selector-parser@^6.0.6: - version "6.0.10" - resolved "https://registry.yarnpkg.com/postcss-selector-parser/-/postcss-selector-parser-6.0.10.tgz#79b61e2c0d1bfc2602d549e11d0876256f8df88d" - integrity sha512-IQ7TZdoaqbT+LCpShg46jnZVlhWD2w6iQYAcYXfHARZ7X1t/UGhhceQDs5X0cGqKvYlHNOuv7Oa1xmb0oQuA3w== - dependencies: - cssesc "^3.0.0" - util-deprecate "^1.0.2" - -postcss-value-parser@^4.0.0, postcss-value-parser@^4.2.0: - version "4.2.0" - resolved "https://registry.yarnpkg.com/postcss-value-parser/-/postcss-value-parser-4.2.0.tgz#723c09920836ba6d3e5af019f92bc0971c02e514" - integrity sha512-1NNCs6uurfkVbeXG4S8JFT9t19m45ICnif8zWLd5oPSZ50QnwMfK+H3jv408d4jw/7Bttv5axS5IiHoLaVNHeQ== - -postcss@^8.4.14: - version "8.4.14" - resolved "https://registry.yarnpkg.com/postcss/-/postcss-8.4.14.tgz#ee9274d5622b4858c1007a74d76e42e56fd21caf" - integrity sha512-E398TUmfAYFPBSdzgeieK2Y1+1cpdxJx8yXbK/m57nRhKSmk1GB2tO4lbLBtlkfPQTDKfe4Xqv1ASWPpayPEig== - dependencies: - nanoid "^3.3.4" - picocolors "^1.0.0" - source-map-js "^1.0.2" - -preact@^10.10.0: - version "10.10.0" - resolved "https://registry.yarnpkg.com/preact/-/preact-10.10.0.tgz#7434750a24b59dae1957d95dc0aa47a4a8e9a180" - integrity sha512-fszkg1iJJjq68I4lI8ZsmBiaoQiQHbxf1lNq+72EmC/mZOsFF5zn3k1yv9QGoFgIXzgsdSKtYymLJsrJPoamjQ== - -qrcode@^1.5.0: - version "1.5.0" - resolved "https://registry.yarnpkg.com/qrcode/-/qrcode-1.5.0.tgz#95abb8a91fdafd86f8190f2836abbfc500c72d1b" - integrity sha512-9MgRpgVc+/+47dFvQeD6U2s0Z92EsKzcHogtum4QB+UNd025WOJSHvn/hjk9xmzj7Stj95CyUAs31mrjxliEsQ== - dependencies: - dijkstrajs "^1.0.1" - encode-utf8 "^1.0.3" - pngjs "^5.0.0" - yargs "^15.3.1" - -queue-microtask@^1.2.2: - version "1.2.3" - resolved "https://registry.yarnpkg.com/queue-microtask/-/queue-microtask-1.2.3.tgz#4929228bbc724dfac43e0efb058caf7b6cfb6243" - integrity sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A== - -quick-lru@^5.1.1: - version "5.1.1" - resolved "https://registry.yarnpkg.com/quick-lru/-/quick-lru-5.1.1.tgz#366493e6b3e42a3a6885e2e99d18f80fb7a8c932" - integrity sha512-WuyALRjWPDGtt/wzJiadO5AXY+8hZ80hVpe6MyivgraREW751X3SbhRvG3eLKOYN+8VEvqLcf3wdnt44Z4S4SA== - -read-cache@^1.0.0: - version "1.0.0" - resolved "https://registry.yarnpkg.com/read-cache/-/read-cache-1.0.0.tgz#e664ef31161166c9751cdbe8dbcf86b5fb58f774" - integrity sha512-Owdv/Ft7IjOgm/i0xvNDZ1LrRANRfew4b2prF3OWMQLxLfu3bS8FVhCsrSCMK4lR56Y9ya+AThoTpDCTxCmpRA== - dependencies: - pify "^2.3.0" - -readdirp@~3.6.0: - version "3.6.0" - resolved "https://registry.yarnpkg.com/readdirp/-/readdirp-3.6.0.tgz#74a370bd857116e245b29cc97340cd431a02a6c7" - integrity sha512-hOS089on8RduqdbhvQ5Z37A0ESjsqz6qnRcffsMU3495FuTdqSm+7bhJ29JvIOsBDEEnan5DPu9t3To9VRlMzA== - dependencies: - picomatch "^2.2.1" - -require-directory@^2.1.1: - version "2.1.1" - resolved "https://registry.yarnpkg.com/require-directory/-/require-directory-2.1.1.tgz#8c64ad5fd30dab1c976e2344ffe7f792a6a6df42" - integrity sha1-jGStX9MNqxyXbiNE/+f3kqam30I= - -require-main-filename@^2.0.0: - version "2.0.0" - resolved "https://registry.yarnpkg.com/require-main-filename/-/require-main-filename-2.0.0.tgz#d0b329ecc7cc0f61649f62215be69af54aa8989b" - integrity sha512-NKN5kMDylKuldxYLSUfrbo5Tuzh4hd+2E8NPPX02mZtn1VuREQToYe/ZdlJy+J3uCpfaiGF05e7B8W0iXbQHmg== - -resolve@^1.1.7, resolve@^1.22.1: - version "1.22.1" - resolved "https://registry.yarnpkg.com/resolve/-/resolve-1.22.1.tgz#27cb2ebb53f91abb49470a928bba7558066ac177" - integrity sha512-nBpuuYuY5jFsli/JIs1oldw6fOQCBioohqWZg/2hiaOybXOft4lonv85uDOKXdf8rhyK159cxU5cDcK/NKk8zw== - dependencies: - is-core-module "^2.9.0" - path-parse "^1.0.7" - supports-preserve-symlinks-flag "^1.0.0" - -reusify@^1.0.4: - version "1.0.4" - resolved "https://registry.yarnpkg.com/reusify/-/reusify-1.0.4.tgz#90da382b1e126efc02146e90845a88db12925d76" - integrity sha512-U9nH88a3fc/ekCF1l0/UP1IosiuIjyTh7hBvXVMHYgVcfGvt897Xguj2UOLDeI5BG2m7/uwyaLVT6fbtCwTyzw== - -run-parallel@^1.1.9: - version "1.2.0" - resolved "https://registry.yarnpkg.com/run-parallel/-/run-parallel-1.2.0.tgz#66d1368da7bdf921eb9d95bd1a9229e7f21a43ee" - integrity sha512-5l4VyZR86LZ/lDxZTR6jqL8AFE2S0IFLMP26AbjsLVADxHdhB/c0GUsH+y39UfCi3dzz8OlQuPmnaJOMoDHQBA== - dependencies: - queue-microtask "^1.2.2" - -set-blocking@^2.0.0: - version "2.0.0" - resolved "https://registry.yarnpkg.com/set-blocking/-/set-blocking-2.0.0.tgz#045f9782d011ae9a6803ddd382b24392b3d890f7" - integrity sha1-BF+XgtARrppoA93TgrJDkrPYkPc= - -source-map-js@^1.0.2: - version "1.0.2" - resolved "https://registry.yarnpkg.com/source-map-js/-/source-map-js-1.0.2.tgz#adbc361d9c62df380125e7f161f71c826f1e490c" - integrity sha512-R0XvVJ9WusLiqTCEiGCmICCMplcCkIwwR11mOSD9CR5u+IXYdiseeEuXCVAjS54zqwkLcPNnmU4OeJ6tUrWhDw== - -string-width@^4.1.0, string-width@^4.2.0, string-width@^4.2.3: - version "4.2.3" - resolved "https://registry.yarnpkg.com/string-width/-/string-width-4.2.3.tgz#269c7117d27b05ad2e536830a8ec895ef9c6d010" - integrity sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g== - dependencies: - emoji-regex "^8.0.0" - is-fullwidth-code-point "^3.0.0" - strip-ansi "^6.0.1" - -strip-ansi@^6.0.0, strip-ansi@^6.0.1: - version "6.0.1" - resolved "https://registry.yarnpkg.com/strip-ansi/-/strip-ansi-6.0.1.tgz#9e26c63d30f53443e9489495b2105d37b67a85d9" - integrity sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A== - dependencies: - ansi-regex "^5.0.1" - -supports-preserve-symlinks-flag@^1.0.0: - version "1.0.0" - resolved "https://registry.yarnpkg.com/supports-preserve-symlinks-flag/-/supports-preserve-symlinks-flag-1.0.0.tgz#6eda4bd344a3c94aea376d4cc31bc77311039e09" - integrity sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w== - -tailwindcss@^3.1.6: - version "3.1.6" - resolved "https://registry.yarnpkg.com/tailwindcss/-/tailwindcss-3.1.6.tgz#bcb719357776c39e6376a8d84e9834b2b19a49f1" - integrity sha512-7skAOY56erZAFQssT1xkpk+kWt2NrO45kORlxFPXUt3CiGsVPhH1smuH5XoDH6sGPXLyBv+zgCKA2HWBsgCytg== - dependencies: - arg "^5.0.2" - chokidar "^3.5.3" - color-name "^1.1.4" - detective "^5.2.1" - didyoumean "^1.2.2" - dlv "^1.1.3" - fast-glob "^3.2.11" - glob-parent "^6.0.2" - is-glob "^4.0.3" - lilconfig "^2.0.5" - normalize-path "^3.0.0" - object-hash "^3.0.0" - picocolors "^1.0.0" - postcss "^8.4.14" - postcss-import "^14.1.0" - postcss-js "^4.0.0" - postcss-load-config "^3.1.4" - postcss-nested "5.0.6" - postcss-selector-parser "^6.0.10" - postcss-value-parser "^4.2.0" - quick-lru "^5.1.1" - resolve "^1.22.1" - -to-regex-range@^5.0.1: - version "5.0.1" - resolved "https://registry.yarnpkg.com/to-regex-range/-/to-regex-range-5.0.1.tgz#1648c44aae7c8d988a326018ed72f5b4dd0392e4" - integrity sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ== - dependencies: - is-number "^7.0.0" - -typescript@>=3.0.1, typescript@^4.7.4: - version "4.7.4" - resolved "https://registry.yarnpkg.com/typescript/-/typescript-4.7.4.tgz#1a88596d1cf47d59507a1bcdfb5b9dfe4d488235" - integrity sha512-C0WQT0gezHuw6AdY1M2jxUO83Rjf0HP7Sk1DtXj6j1EwkQNZrHAg2XPWlq62oqEhYvONq5pkC2Y9oPljWToLmQ== - -util-deprecate@^1.0.2: - version "1.0.2" - resolved "https://registry.yarnpkg.com/util-deprecate/-/util-deprecate-1.0.2.tgz#450d4dc9fa70de732762fbd2d4a28981419a0ccf" - integrity sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw== - -which-module@^2.0.0: - version "2.0.0" - resolved "https://registry.yarnpkg.com/which-module/-/which-module-2.0.0.tgz#d9ef07dce77b9902b8a3a8fa4b31c3e3f7e6e87a" - integrity sha1-2e8H3Od7mQK4o6j6SzHD4/fm6Ho= - -wrap-ansi@^6.2.0: - version "6.2.0" - resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-6.2.0.tgz#e9393ba07102e6c91a3b221478f0257cd2856e53" - integrity sha512-r6lPcBGxZXlIcymEu7InxDMhdW0KDxpLgoFLcguasxCaJ/SOIZwINatK9KY/tf+ZrlywOKU0UDj3ATXUBfxJXA== - dependencies: - ansi-styles "^4.0.0" - string-width "^4.1.0" - strip-ansi "^6.0.0" - -wrap-ansi@^7.0.0: - version "7.0.0" - resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-7.0.0.tgz#67e145cff510a6a6984bdf1152911d69d2eb9e43" - integrity sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q== - dependencies: - ansi-styles "^4.0.0" - string-width "^4.1.0" - strip-ansi "^6.0.0" - -xtend@^4.0.2: - version "4.0.2" - resolved "https://registry.yarnpkg.com/xtend/-/xtend-4.0.2.tgz#bb72779f5fa465186b1f438f674fa347fdb5db54" - integrity sha512-LKYU1iAXJXUgAXn9URjiu+MWhyUXHsvfp7mcuYm9dSUKK0/CjtrUwFAxD82/mCWbtLsGjFIad0wIsod4zrTAEQ== - -xterm-addon-fit@^0.7.0: - version "0.7.0" - resolved "https://registry.yarnpkg.com/xterm-addon-fit/-/xterm-addon-fit-0.7.0.tgz#b8ade6d96e63b47443862088f6670b49fb752c6a" - integrity sha512-tQgHGoHqRTgeROPnvmtEJywLKoC/V9eNs4bLLz7iyJr1aW/QFzRwfd3MGiJ6odJd9xEfxcW36/xRU47JkD5NKQ== - -xterm-addon-web-links@^0.8.0: - version "0.8.0" - resolved "https://registry.yarnpkg.com/xterm-addon-web-links/-/xterm-addon-web-links-0.8.0.tgz#2cb1d57129271022569208578b0bf4774e7e6ea9" - integrity sha512-J4tKngmIu20ytX9SEJjAP3UGksah7iALqBtfTwT9ZnmFHVplCumYQsUJfKuS+JwMhjsjH61YXfndenLNvjRrEw== - -xterm@^5.1.0: - version "5.1.0" - resolved "https://registry.yarnpkg.com/xterm/-/xterm-5.1.0.tgz#3e160d60e6801c864b55adf19171c49d2ff2b4fc" - integrity sha512-LovENH4WDzpwynj+OTkLyZgJPeDom9Gra4DMlGAgz6pZhIDCQ+YuO7yfwanY+gVbn/mmZIStNOnVRU/ikQuAEQ== - -y18n@^4.0.0: - version "4.0.3" - resolved "https://registry.yarnpkg.com/y18n/-/y18n-4.0.3.tgz#b5f259c82cd6e336921efd7bfd8bf560de9eeedf" - integrity sha512-JKhqTOwSrqNA1NY5lSztJ1GrBiUodLMmIZuLiDaMRJ+itFd+ABVE8XBjOvIWL+rSqNDC74LCSFmlb/U4UZ4hJQ== - -y18n@^5.0.5: - version "5.0.8" - resolved "https://registry.yarnpkg.com/y18n/-/y18n-5.0.8.tgz#7f4934d0f7ca8c56f95314939ddcd2dd91ce1d55" - integrity sha512-0pfFzegeDWJHJIAmTLRP2DwHjdF5s7jo9tuztdQxAhINCdvS+3nGINqPd00AphqJR/0LhANUS6/+7SCb98YOfA== - -yaml@^1.10.2: - version "1.10.2" - resolved "https://registry.yarnpkg.com/yaml/-/yaml-1.10.2.tgz#2301c5ffbf12b467de8da2333a459e29e7920e4b" - integrity sha512-r3vXyErRCYJ7wg28yvBY5VSoAF8ZvlcW9/BwUzEtUsjvX/DKs24dIkuwjtuprwJJHsbyUbLApepYTR1BN4uHrg== - -yargs-parser@^18.1.2: - version "18.1.3" - resolved "https://registry.yarnpkg.com/yargs-parser/-/yargs-parser-18.1.3.tgz#be68c4975c6b2abf469236b0c870362fab09a7b0" - integrity sha512-o50j0JeToy/4K6OZcaQmW6lyXXKhq7csREXcDwk2omFPJEwUNOVtJKvmDr9EI1fAJZUyZcRF7kxGBWmRXudrCQ== - dependencies: - camelcase "^5.0.0" - decamelize "^1.2.0" - -yargs-parser@^21.0.0: - version "21.1.1" - resolved "https://registry.yarnpkg.com/yargs-parser/-/yargs-parser-21.1.1.tgz#9096bceebf990d21bb31fa9516e0ede294a77d35" - integrity sha512-tVpsJW7DdjecAiFpbIB1e3qxIQsE6NoPc5/eTdrbbIC4h0LVsWhnoa3g+m2HclBIujHzsxZ4VJVA+GUuc2/LBw== - -yargs@^15.3.1: - version "15.4.1" - resolved "https://registry.yarnpkg.com/yargs/-/yargs-15.4.1.tgz#0d87a16de01aee9d8bec2bfbf74f67851730f4f8" - integrity sha512-aePbxDmcYW++PaqBsJ+HYUFwCdv4LVvdnhBy78E57PIor8/OVvhMrADFFEDh8DHDFRv/O9i3lPhsENjO7QX0+A== - dependencies: - cliui "^6.0.0" - decamelize "^1.2.0" - find-up "^4.1.0" - get-caller-file "^2.0.1" - require-directory "^2.1.1" - require-main-filename "^2.0.0" - set-blocking "^2.0.0" - string-width "^4.2.0" - which-module "^2.0.0" - y18n "^4.0.0" - yargs-parser "^18.1.2" - -yargs@^17.2.1: - version "17.5.1" - resolved "https://registry.yarnpkg.com/yargs/-/yargs-17.5.1.tgz#e109900cab6fcb7fd44b1d8249166feb0b36e58e" - integrity sha512-t6YAJcxDkNX7NFYiVtKvWUz8l+PaKTLiL63mJYWR2GnHq2gjEWISzsLp9wg3aY36dY1j+gfIEL3pIF+XlJJfbA== - dependencies: - cliui "^7.0.2" - escalade "^3.1.1" - get-caller-file "^2.0.5" - require-directory "^2.1.1" - string-width "^4.2.3" - y18n "^5.0.5" - yargs-parser "^21.0.0" +# THIS IS AN AUTOGENERATED FILE. DO NOT EDIT THIS FILE DIRECTLY. +# yarn lockfile v1 + + +"@nodelib/fs.scandir@2.1.5": + version "2.1.5" + resolved "https://registry.yarnpkg.com/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz#7619c2eb21b25483f6d167548b4cfd5a7488c3d5" + integrity sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g== + dependencies: + "@nodelib/fs.stat" "2.0.5" + run-parallel "^1.1.9" + +"@nodelib/fs.stat@2.0.5", "@nodelib/fs.stat@^2.0.2": + version "2.0.5" + resolved "https://registry.yarnpkg.com/@nodelib/fs.stat/-/fs.stat-2.0.5.tgz#5bd262af94e9d25bd1e71b05deed44876a222e8b" + integrity sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A== + +"@nodelib/fs.walk@^1.2.3": + version "1.2.8" + resolved "https://registry.yarnpkg.com/@nodelib/fs.walk/-/fs.walk-1.2.8.tgz#e95737e8bb6746ddedf69c556953494f196fe69a" + integrity sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg== + dependencies: + "@nodelib/fs.scandir" "2.1.5" + fastq "^1.6.0" + +"@types/golang-wasm-exec@^1.15.0": + version "1.15.0" + resolved "https://registry.yarnpkg.com/@types/golang-wasm-exec/-/golang-wasm-exec-1.15.0.tgz#d0aafbb2b0dc07eaf45dfb83bfb6cdd5b2b3c55c" + integrity sha512-FrL97mp7WW8LqNinVkzTVKOIQKuYjQqgucnh41+1vRQ+bf1LT8uh++KRf9otZPXsa6H1p8ruIGz1BmCGttOL6Q== + +"@types/node@*": + version "18.6.1" + resolved "https://registry.yarnpkg.com/@types/node/-/node-18.6.1.tgz#828e4785ccca13f44e2fb6852ae0ef11e3e20ba5" + integrity sha512-z+2vB6yDt1fNwKOeGbckpmirO+VBDuQqecXkgeIqDlaOtmKn6hPR/viQ8cxCfqLU4fTlvM3+YjM367TukWdxpg== + +"@types/qrcode@^1.4.2": + version "1.4.2" + resolved "https://registry.yarnpkg.com/@types/qrcode/-/qrcode-1.4.2.tgz#7d7142d6fa9921f195db342ed08b539181546c74" + integrity sha512-7uNT9L4WQTNJejHTSTdaJhfBSCN73xtXaHFyBJ8TSwiLhe4PRuTue7Iph0s2nG9R/ifUaSnGhLUOZavlBEqDWQ== + dependencies: + "@types/node" "*" + +acorn-node@^1.8.2: + version "1.8.2" + resolved "https://registry.yarnpkg.com/acorn-node/-/acorn-node-1.8.2.tgz#114c95d64539e53dede23de8b9d96df7c7ae2af8" + integrity sha512-8mt+fslDufLYntIoPAaIMUe/lrbrehIiwmR3t2k9LljIzoigEPF27eLk2hy8zSGzmR/ogr7zbRKINMo1u0yh5A== + dependencies: + acorn "^7.0.0" + acorn-walk "^7.0.0" + xtend "^4.0.2" + +acorn-walk@^7.0.0: + version "7.2.0" + resolved "https://registry.yarnpkg.com/acorn-walk/-/acorn-walk-7.2.0.tgz#0de889a601203909b0fbe07b8938dc21d2e967bc" + integrity sha512-OPdCF6GsMIP+Az+aWfAAOEt2/+iVDKE7oy6lJ098aoe59oAmK76qV6Gw60SbZ8jHuG2wH058GF4pLFbYamYrVA== + +acorn@^7.0.0: + version "7.4.1" + resolved "https://registry.yarnpkg.com/acorn/-/acorn-7.4.1.tgz#feaed255973d2e77555b83dbc08851a6c63520fa" + integrity sha512-nQyp0o1/mNdbTO1PO6kHkwSrmgZ0MT/jCCpNiwbUjGoRN4dlBhqJtoQuCnEOKzgTVwg0ZWiCoQy6SxMebQVh8A== + +ansi-regex@^5.0.1: + version "5.0.1" + resolved "https://registry.yarnpkg.com/ansi-regex/-/ansi-regex-5.0.1.tgz#082cb2c89c9fe8659a311a53bd6a4dc5301db304" + integrity sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ== + +ansi-styles@^4.0.0: + version "4.3.0" + resolved "https://registry.yarnpkg.com/ansi-styles/-/ansi-styles-4.3.0.tgz#edd803628ae71c04c85ae7a0906edad34b648937" + integrity sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg== + dependencies: + color-convert "^2.0.1" + +anymatch@~3.1.2: + version "3.1.2" + resolved "https://registry.yarnpkg.com/anymatch/-/anymatch-3.1.2.tgz#c0557c096af32f106198f4f4e2a383537e378716" + integrity sha512-P43ePfOAIupkguHUycrc4qJ9kz8ZiuOUijaETwX7THt0Y/GNK7v0aa8rY816xWjZ7rJdA5XdMcpVFTKMq+RvWg== + dependencies: + normalize-path "^3.0.0" + picomatch "^2.0.4" + +arg@^5.0.2: + version "5.0.2" + resolved "https://registry.yarnpkg.com/arg/-/arg-5.0.2.tgz#c81433cc427c92c4dcf4865142dbca6f15acd59c" + integrity sha512-PYjyFOLKQ9y57JvQ6QLo8dAgNqswh8M1RMJYdQduT6xbWSgK36P/Z/v+p888pM69jMMfS8Xd8F6I1kQ/I9HUGg== + +binary-extensions@^2.0.0: + version "2.2.0" + resolved "https://registry.yarnpkg.com/binary-extensions/-/binary-extensions-2.2.0.tgz#75f502eeaf9ffde42fc98829645be4ea76bd9e2d" + integrity sha512-jDctJ/IVQbZoJykoeHbhXpOlNBqGNcwXJKJog42E5HDPUwQTSdjCHdihjj0DlnheQ7blbT6dHOafNAiS8ooQKA== + +braces@^3.0.2, braces@~3.0.2: + version "3.0.2" + resolved "https://registry.yarnpkg.com/braces/-/braces-3.0.2.tgz#3454e1a462ee8d599e236df336cd9ea4f8afe107" + integrity sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A== + dependencies: + fill-range "^7.0.1" + +camelcase-css@^2.0.1: + version "2.0.1" + resolved "https://registry.yarnpkg.com/camelcase-css/-/camelcase-css-2.0.1.tgz#ee978f6947914cc30c6b44741b6ed1df7f043fd5" + integrity sha512-QOSvevhslijgYwRx6Rv7zKdMF8lbRmx+uQGx2+vDc+KI/eBnsy9kit5aj23AgGu3pa4t9AgwbnXWqS+iOY+2aA== + +camelcase@^5.0.0: + version "5.3.1" + resolved "https://registry.yarnpkg.com/camelcase/-/camelcase-5.3.1.tgz#e3c9b31569e106811df242f715725a1f4c494320" + integrity sha512-L28STB170nwWS63UjtlEOE3dldQApaJXZkOI1uMFfzf3rRuPegHaHesyee+YxQ+W6SvRDQV6UrdOdRiR153wJg== + +chokidar@^3.5.3: + version "3.5.3" + resolved "https://registry.yarnpkg.com/chokidar/-/chokidar-3.5.3.tgz#1cf37c8707b932bd1af1ae22c0432e2acd1903bd" + integrity sha512-Dr3sfKRP6oTcjf2JmUmFJfeVMvXBdegxB0iVQ5eb2V10uFJUCAS8OByZdVAyVb8xXNz3GjjTgj9kLWsZTqE6kw== + dependencies: + anymatch "~3.1.2" + braces "~3.0.2" + glob-parent "~5.1.2" + is-binary-path "~2.1.0" + is-glob "~4.0.1" + normalize-path "~3.0.0" + readdirp "~3.6.0" + optionalDependencies: + fsevents "~2.3.2" + +cliui@^6.0.0: + version "6.0.0" + resolved "https://registry.yarnpkg.com/cliui/-/cliui-6.0.0.tgz#511d702c0c4e41ca156d7d0e96021f23e13225b1" + integrity sha512-t6wbgtoCXvAzst7QgXxJYqPt0usEfbgQdftEPbLL/cvv6HPE5VgvqCuAIDR0NgU52ds6rFwqrgakNLrHEjCbrQ== + dependencies: + string-width "^4.2.0" + strip-ansi "^6.0.0" + wrap-ansi "^6.2.0" + +cliui@^7.0.2: + version "7.0.4" + resolved "https://registry.yarnpkg.com/cliui/-/cliui-7.0.4.tgz#a0265ee655476fc807aea9df3df8df7783808b4f" + integrity sha512-OcRE68cOsVMXp1Yvonl/fzkQOyjLSu/8bhPDfQt0e0/Eb283TKP20Fs2MqoPsr9SwA595rRCA+QMzYc9nBP+JQ== + dependencies: + string-width "^4.2.0" + strip-ansi "^6.0.0" + wrap-ansi "^7.0.0" + +color-convert@^2.0.1: + version "2.0.1" + resolved "https://registry.yarnpkg.com/color-convert/-/color-convert-2.0.1.tgz#72d3a68d598c9bdb3af2ad1e84f21d896abd4de3" + integrity sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ== + dependencies: + color-name "~1.1.4" + +color-name@^1.1.4, color-name@~1.1.4: + version "1.1.4" + resolved "https://registry.yarnpkg.com/color-name/-/color-name-1.1.4.tgz#c2a09a87acbde69543de6f63fa3995c826c536a2" + integrity sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA== + +cssesc@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/cssesc/-/cssesc-3.0.0.tgz#37741919903b868565e1c09ea747445cd18983ee" + integrity sha512-/Tb/JcjK111nNScGob5MNtsntNM1aCNUDipB/TkwZFhyDrrE47SOx/18wF2bbjgc3ZzCSKW1T5nt5EbFoAz/Vg== + +decamelize@^1.2.0: + version "1.2.0" + resolved "https://registry.yarnpkg.com/decamelize/-/decamelize-1.2.0.tgz#f6534d15148269b20352e7bee26f501f9a191290" + integrity sha1-9lNNFRSCabIDUue+4m9QH5oZEpA= + +defined@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/defined/-/defined-1.0.0.tgz#c98d9bcef75674188e110969151199e39b1fa693" + integrity sha512-Y2caI5+ZwS5c3RiNDJ6u53VhQHv+hHKwhkI1iHvceKUHw9Df6EK2zRLfjejRgMuCuxK7PfSWIMwWecceVvThjQ== + +detective@^5.2.1: + version "5.2.1" + resolved "https://registry.yarnpkg.com/detective/-/detective-5.2.1.tgz#6af01eeda11015acb0e73f933242b70f24f91034" + integrity sha512-v9XE1zRnz1wRtgurGu0Bs8uHKFSTdteYZNbIPFVhUZ39L/S79ppMpdmVOZAnoz1jfEFodc48n6MX483Xo3t1yw== + dependencies: + acorn-node "^1.8.2" + defined "^1.0.0" + minimist "^1.2.6" + +didyoumean@^1.2.2: + version "1.2.2" + resolved "https://registry.yarnpkg.com/didyoumean/-/didyoumean-1.2.2.tgz#989346ffe9e839b4555ecf5666edea0d3e8ad037" + integrity sha512-gxtyfqMg7GKyhQmb056K7M3xszy/myH8w+B4RT+QXBQsvAOdc3XymqDDPHx1BgPgsdAA5SIifona89YtRATDzw== + +dijkstrajs@^1.0.1: + version "1.0.2" + resolved "https://registry.yarnpkg.com/dijkstrajs/-/dijkstrajs-1.0.2.tgz#2e48c0d3b825462afe75ab4ad5e829c8ece36257" + integrity sha512-QV6PMaHTCNmKSeP6QoXhVTw9snc9VD8MulTT0Bd99Pacp4SS1cjcrYPgBPmibqKVtMJJfqC6XvOXgPMEEPH/fg== + +dlv@^1.1.3: + version "1.1.3" + resolved "https://registry.yarnpkg.com/dlv/-/dlv-1.1.3.tgz#5c198a8a11453596e751494d49874bc7732f2e79" + integrity sha512-+HlytyjlPKnIG8XuRG8WvmBP8xs8P71y+SKKS6ZXWoEgLuePxtDoUEiH7WkdePWrQ5JBpE6aoVqfZfJUQkjXwA== + +dts-bundle-generator@^6.12.0: + version "6.12.0" + resolved "https://registry.yarnpkg.com/dts-bundle-generator/-/dts-bundle-generator-6.12.0.tgz#0a221bdce5fdd309a56c8556e645f16ed87ab07d" + integrity sha512-k/QAvuVaLIdyWRUHduDrWBe4j8PcE6TDt06+f32KHbW7/SmUPbX1O23fFtQgKwUyTBkbIjJFOFtNrF97tJcKug== + dependencies: + typescript ">=3.0.1" + yargs "^17.2.1" + +emoji-regex@^8.0.0: + version "8.0.0" + resolved "https://registry.yarnpkg.com/emoji-regex/-/emoji-regex-8.0.0.tgz#e818fd69ce5ccfcb404594f842963bf53164cc37" + integrity sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A== + +encode-utf8@^1.0.3: + version "1.0.3" + resolved "https://registry.yarnpkg.com/encode-utf8/-/encode-utf8-1.0.3.tgz#f30fdd31da07fb596f281beb2f6b027851994cda" + integrity sha512-ucAnuBEhUK4boH2HjVYG5Q2mQyPorvv0u/ocS+zhdw0S8AlHYY+GOFhP1Gio5z4icpP2ivFSvhtFjQi8+T9ppw== + +escalade@^3.1.1: + version "3.1.1" + resolved "https://registry.yarnpkg.com/escalade/-/escalade-3.1.1.tgz#d8cfdc7000965c5a0174b4a82eaa5c0552742e40" + integrity sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw== + +fast-glob@^3.2.11: + version "3.2.11" + resolved "https://registry.yarnpkg.com/fast-glob/-/fast-glob-3.2.11.tgz#a1172ad95ceb8a16e20caa5c5e56480e5129c1d9" + integrity sha512-xrO3+1bxSo3ZVHAnqzyuewYT6aMFHRAd4Kcs92MAonjwQZLsK9d0SF1IyQ3k5PoirxTW0Oe/RqFgMQ6TcNE5Ew== + dependencies: + "@nodelib/fs.stat" "^2.0.2" + "@nodelib/fs.walk" "^1.2.3" + glob-parent "^5.1.2" + merge2 "^1.3.0" + micromatch "^4.0.4" + +fastq@^1.6.0: + version "1.13.0" + resolved "https://registry.yarnpkg.com/fastq/-/fastq-1.13.0.tgz#616760f88a7526bdfc596b7cab8c18938c36b98c" + integrity sha512-YpkpUnK8od0o1hmeSc7UUs/eB/vIPWJYjKck2QKIzAf71Vm1AAQ3EbuZB3g2JIy+pg+ERD0vqI79KyZiB2e2Nw== + dependencies: + reusify "^1.0.4" + +fill-range@^7.0.1: + version "7.0.1" + resolved "https://registry.yarnpkg.com/fill-range/-/fill-range-7.0.1.tgz#1919a6a7c75fe38b2c7c77e5198535da9acdda40" + integrity sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ== + dependencies: + to-regex-range "^5.0.1" + +find-up@^4.1.0: + version "4.1.0" + resolved "https://registry.yarnpkg.com/find-up/-/find-up-4.1.0.tgz#97afe7d6cdc0bc5928584b7c8d7b16e8a9aa5d19" + integrity sha512-PpOwAdQ/YlXQ2vj8a3h8IipDuYRi3wceVQQGYWxNINccq40Anw7BlsEXCMbt1Zt+OLA6Fq9suIpIWD0OsnISlw== + dependencies: + locate-path "^5.0.0" + path-exists "^4.0.0" + +fsevents@~2.3.2: + version "2.3.2" + resolved "https://registry.yarnpkg.com/fsevents/-/fsevents-2.3.2.tgz#8a526f78b8fdf4623b709e0b975c52c24c02fd1a" + integrity sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA== + +function-bind@^1.1.1: + version "1.1.1" + resolved "https://registry.yarnpkg.com/function-bind/-/function-bind-1.1.1.tgz#a56899d3ea3c9bab874bb9773b7c5ede92f4895d" + integrity sha512-yIovAzMX49sF8Yl58fSCWJ5svSLuaibPxXQJFLmBObTuCr0Mf1KiPopGM9NiFjiYBCbfaa2Fh6breQ6ANVTI0A== + +get-caller-file@^2.0.1, get-caller-file@^2.0.5: + version "2.0.5" + resolved "https://registry.yarnpkg.com/get-caller-file/-/get-caller-file-2.0.5.tgz#4f94412a82db32f36e3b0b9741f8a97feb031f7e" + integrity sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg== + +glob-parent@^5.1.2, glob-parent@~5.1.2: + version "5.1.2" + resolved "https://registry.yarnpkg.com/glob-parent/-/glob-parent-5.1.2.tgz#869832c58034fe68a4093c17dc15e8340d8401c4" + integrity sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow== + dependencies: + is-glob "^4.0.1" + +glob-parent@^6.0.2: + version "6.0.2" + resolved "https://registry.yarnpkg.com/glob-parent/-/glob-parent-6.0.2.tgz#6d237d99083950c79290f24c7642a3de9a28f9e3" + integrity sha512-XxwI8EOhVQgWp6iDL+3b0r86f4d6AX6zSU55HfB4ydCEuXLXc5FcYeOu+nnGftS4TEju/11rt4KJPTMgbfmv4A== + dependencies: + is-glob "^4.0.3" + +has@^1.0.3: + version "1.0.3" + resolved "https://registry.yarnpkg.com/has/-/has-1.0.3.tgz#722d7cbfc1f6aa8241f16dd814e011e1f41e8796" + integrity sha512-f2dvO0VU6Oej7RkWJGrehjbzMAjFp5/VKPp5tTpWIV4JHHZK1/BxbFRtf/siA2SWTe09caDmVtYYzWEIbBS4zw== + dependencies: + function-bind "^1.1.1" + +is-binary-path@~2.1.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/is-binary-path/-/is-binary-path-2.1.0.tgz#ea1f7f3b80f064236e83470f86c09c254fb45b09" + integrity sha512-ZMERYes6pDydyuGidse7OsHxtbI7WVeUEozgR/g7rd0xUimYNlvZRE/K2MgZTjWy725IfelLeVcEM97mmtRGXw== + dependencies: + binary-extensions "^2.0.0" + +is-core-module@^2.9.0: + version "2.9.0" + resolved "https://registry.yarnpkg.com/is-core-module/-/is-core-module-2.9.0.tgz#e1c34429cd51c6dd9e09e0799e396e27b19a9c69" + integrity sha512-+5FPy5PnwmO3lvfMb0AsoPaBG+5KHUI0wYFXOtYPnVVVspTFUuMZNfNaNVRt3FZadstu2c8x23vykRW/NBoU6A== + dependencies: + has "^1.0.3" + +is-extglob@^2.1.1: + version "2.1.1" + resolved "https://registry.yarnpkg.com/is-extglob/-/is-extglob-2.1.1.tgz#a88c02535791f02ed37c76a1b9ea9773c833f8c2" + integrity sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ== + +is-fullwidth-code-point@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz#f116f8064fe90b3f7844a38997c0b75051269f1d" + integrity sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg== + +is-glob@^4.0.1, is-glob@^4.0.3, is-glob@~4.0.1: + version "4.0.3" + resolved "https://registry.yarnpkg.com/is-glob/-/is-glob-4.0.3.tgz#64f61e42cbbb2eec2071a9dac0b28ba1e65d5084" + integrity sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg== + dependencies: + is-extglob "^2.1.1" + +is-number@^7.0.0: + version "7.0.0" + resolved "https://registry.yarnpkg.com/is-number/-/is-number-7.0.0.tgz#7535345b896734d5f80c4d06c50955527a14f12b" + integrity sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng== + +lilconfig@^2.0.5: + version "2.0.6" + resolved "https://registry.yarnpkg.com/lilconfig/-/lilconfig-2.0.6.tgz#32a384558bd58af3d4c6e077dd1ad1d397bc69d4" + integrity sha512-9JROoBW7pobfsx+Sq2JsASvCo6Pfo6WWoUW79HuB1BCoBXD4PLWJPqDF6fNj67pqBYTbAHkE57M1kS/+L1neOg== + +locate-path@^5.0.0: + version "5.0.0" + resolved "https://registry.yarnpkg.com/locate-path/-/locate-path-5.0.0.tgz#1afba396afd676a6d42504d0a67a3a7eb9f62aa0" + integrity sha512-t7hw9pI+WvuwNJXwk5zVHpyhIqzg2qTlklJOf0mVxGSbe3Fp2VieZcduNYjaLDoy6p9uGpQEGWG87WpMKlNq8g== + dependencies: + p-locate "^4.1.0" + +merge2@^1.3.0: + version "1.4.1" + resolved "https://registry.yarnpkg.com/merge2/-/merge2-1.4.1.tgz#4368892f885e907455a6fd7dc55c0c9d404990ae" + integrity sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg== + +micromatch@^4.0.4: + version "4.0.5" + resolved "https://registry.yarnpkg.com/micromatch/-/micromatch-4.0.5.tgz#bc8999a7cbbf77cdc89f132f6e467051b49090c6" + integrity sha512-DMy+ERcEW2q8Z2Po+WNXuw3c5YaUSFjAO5GsJqfEl7UjvtIuFKO6ZrKvcItdy98dwFI2N1tg3zNIdKaQT+aNdA== + dependencies: + braces "^3.0.2" + picomatch "^2.3.1" + +minimist@^1.2.6: + version "1.2.6" + resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.6.tgz#8637a5b759ea0d6e98702cfb3a9283323c93af44" + integrity sha512-Jsjnk4bw3YJqYzbdyBiNsPWHPfO++UGG749Cxs6peCu5Xg4nrena6OVxOYxrQTqww0Jmwt+Ref8rggumkTLz9Q== + +nanoid@^3.3.4: + version "3.3.4" + resolved "https://registry.yarnpkg.com/nanoid/-/nanoid-3.3.4.tgz#730b67e3cd09e2deacf03c027c81c9d9dbc5e8ab" + integrity sha512-MqBkQh/OHTS2egovRtLk45wEyNXwF+cokD+1YPf9u5VfJiRdAiRwB2froX5Co9Rh20xs4siNPm8naNotSD6RBw== + +normalize-path@^3.0.0, normalize-path@~3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/normalize-path/-/normalize-path-3.0.0.tgz#0dcd69ff23a1c9b11fd0978316644a0388216a65" + integrity sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA== + +object-hash@^3.0.0: + version "3.0.0" + resolved "https://registry.yarnpkg.com/object-hash/-/object-hash-3.0.0.tgz#73f97f753e7baffc0e2cc9d6e079079744ac82e9" + integrity sha512-RSn9F68PjH9HqtltsSnqYC1XXoWe9Bju5+213R98cNGttag9q9yAOTzdbsqvIa7aNm5WffBZFpWYr2aWrklWAw== + +p-limit@^2.2.0: + version "2.3.0" + resolved "https://registry.yarnpkg.com/p-limit/-/p-limit-2.3.0.tgz#3dd33c647a214fdfffd835933eb086da0dc21db1" + integrity sha512-//88mFWSJx8lxCzwdAABTJL2MyWB12+eIY7MDL2SqLmAkeKU9qxRvWuSyTjm3FUmpBEMuFfckAIqEaVGUDxb6w== + dependencies: + p-try "^2.0.0" + +p-locate@^4.1.0: + version "4.1.0" + resolved "https://registry.yarnpkg.com/p-locate/-/p-locate-4.1.0.tgz#a3428bb7088b3a60292f66919278b7c297ad4f07" + integrity sha512-R79ZZ/0wAxKGu3oYMlz8jy/kbhsNrS7SKZ7PxEHBgJ5+F2mtFW2fK2cOtBh1cHYkQsbzFV7I+EoRKe6Yt0oK7A== + dependencies: + p-limit "^2.2.0" + +p-try@^2.0.0: + version "2.2.0" + resolved "https://registry.yarnpkg.com/p-try/-/p-try-2.2.0.tgz#cb2868540e313d61de58fafbe35ce9004d5540e6" + integrity sha512-R4nPAVTAU0B9D35/Gk3uJf/7XYbQcyohSKdvAxIRSNghFl4e71hVoGnBNQz9cWaXxO2I10KTC+3jMdvvoKw6dQ== + +path-exists@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/path-exists/-/path-exists-4.0.0.tgz#513bdbe2d3b95d7762e8c1137efa195c6c61b5b3" + integrity sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w== + +path-parse@^1.0.7: + version "1.0.7" + resolved "https://registry.yarnpkg.com/path-parse/-/path-parse-1.0.7.tgz#fbc114b60ca42b30d9daf5858e4bd68bbedb6735" + integrity sha512-LDJzPVEEEPR+y48z93A0Ed0yXb8pAByGWo/k5YYdYgpY2/2EsOsksJrq7lOHxryrVOn1ejG6oAp8ahvOIQD8sw== + +picocolors@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/picocolors/-/picocolors-1.0.0.tgz#cb5bdc74ff3f51892236eaf79d68bc44564ab81c" + integrity sha512-1fygroTLlHu66zi26VoTDv8yRgm0Fccecssto+MhsZ0D/DGW2sm8E8AjW7NU5VVTRt5GxbeZ5qBuJr+HyLYkjQ== + +picomatch@^2.0.4, picomatch@^2.2.1, picomatch@^2.3.1: + version "2.3.1" + resolved "https://registry.yarnpkg.com/picomatch/-/picomatch-2.3.1.tgz#3ba3833733646d9d3e4995946c1365a67fb07a42" + integrity sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA== + +pify@^2.3.0: + version "2.3.0" + resolved "https://registry.yarnpkg.com/pify/-/pify-2.3.0.tgz#ed141a6ac043a849ea588498e7dca8b15330e90c" + integrity sha512-udgsAY+fTnvv7kI7aaxbqwWNb0AHiB0qBO89PZKPkoTmGOgdbrHDKD+0B2X4uTfJ/FT1R09r9gTsjUjNJotuog== + +pngjs@^5.0.0: + version "5.0.0" + resolved "https://registry.yarnpkg.com/pngjs/-/pngjs-5.0.0.tgz#e79dd2b215767fd9c04561c01236df960bce7fbb" + integrity sha512-40QW5YalBNfQo5yRYmiw7Yz6TKKVr3h6970B2YE+3fQpsWcrbj1PzJgxeJ19DRQjhMbKPIuMY8rFaXc8moolVw== + +postcss-import@^14.1.0: + version "14.1.0" + resolved "https://registry.yarnpkg.com/postcss-import/-/postcss-import-14.1.0.tgz#a7333ffe32f0b8795303ee9e40215dac922781f0" + integrity sha512-flwI+Vgm4SElObFVPpTIT7SU7R3qk2L7PyduMcokiaVKuWv9d/U+Gm/QAd8NDLuykTWTkcrjOeD2Pp1rMeBTGw== + dependencies: + postcss-value-parser "^4.0.0" + read-cache "^1.0.0" + resolve "^1.1.7" + +postcss-js@^4.0.0: + version "4.0.0" + resolved "https://registry.yarnpkg.com/postcss-js/-/postcss-js-4.0.0.tgz#31db79889531b80dc7bc9b0ad283e418dce0ac00" + integrity sha512-77QESFBwgX4irogGVPgQ5s07vLvFqWr228qZY+w6lW599cRlK/HmnlivnnVUxkjHnCu4J16PDMHcH+e+2HbvTQ== + dependencies: + camelcase-css "^2.0.1" + +postcss-load-config@^3.1.4: + version "3.1.4" + resolved "https://registry.yarnpkg.com/postcss-load-config/-/postcss-load-config-3.1.4.tgz#1ab2571faf84bb078877e1d07905eabe9ebda855" + integrity sha512-6DiM4E7v4coTE4uzA8U//WhtPwyhiim3eyjEMFCnUpzbrkK9wJHgKDT2mR+HbtSrd/NubVaYTOpSpjUl8NQeRg== + dependencies: + lilconfig "^2.0.5" + yaml "^1.10.2" + +postcss-nested@5.0.6: + version "5.0.6" + resolved "https://registry.yarnpkg.com/postcss-nested/-/postcss-nested-5.0.6.tgz#466343f7fc8d3d46af3e7dba3fcd47d052a945bc" + integrity sha512-rKqm2Fk0KbA8Vt3AdGN0FB9OBOMDVajMG6ZCf/GoHgdxUJ4sBFp0A/uMIRm+MJUdo33YXEtjqIz8u7DAp8B7DA== + dependencies: + postcss-selector-parser "^6.0.6" + +postcss-selector-parser@^6.0.10, postcss-selector-parser@^6.0.6: + version "6.0.10" + resolved "https://registry.yarnpkg.com/postcss-selector-parser/-/postcss-selector-parser-6.0.10.tgz#79b61e2c0d1bfc2602d549e11d0876256f8df88d" + integrity sha512-IQ7TZdoaqbT+LCpShg46jnZVlhWD2w6iQYAcYXfHARZ7X1t/UGhhceQDs5X0cGqKvYlHNOuv7Oa1xmb0oQuA3w== + dependencies: + cssesc "^3.0.0" + util-deprecate "^1.0.2" + +postcss-value-parser@^4.0.0, postcss-value-parser@^4.2.0: + version "4.2.0" + resolved "https://registry.yarnpkg.com/postcss-value-parser/-/postcss-value-parser-4.2.0.tgz#723c09920836ba6d3e5af019f92bc0971c02e514" + integrity sha512-1NNCs6uurfkVbeXG4S8JFT9t19m45ICnif8zWLd5oPSZ50QnwMfK+H3jv408d4jw/7Bttv5axS5IiHoLaVNHeQ== + +postcss@^8.4.14: + version "8.4.14" + resolved "https://registry.yarnpkg.com/postcss/-/postcss-8.4.14.tgz#ee9274d5622b4858c1007a74d76e42e56fd21caf" + integrity sha512-E398TUmfAYFPBSdzgeieK2Y1+1cpdxJx8yXbK/m57nRhKSmk1GB2tO4lbLBtlkfPQTDKfe4Xqv1ASWPpayPEig== + dependencies: + nanoid "^3.3.4" + picocolors "^1.0.0" + source-map-js "^1.0.2" + +preact@^10.10.0: + version "10.10.0" + resolved "https://registry.yarnpkg.com/preact/-/preact-10.10.0.tgz#7434750a24b59dae1957d95dc0aa47a4a8e9a180" + integrity sha512-fszkg1iJJjq68I4lI8ZsmBiaoQiQHbxf1lNq+72EmC/mZOsFF5zn3k1yv9QGoFgIXzgsdSKtYymLJsrJPoamjQ== + +qrcode@^1.5.0: + version "1.5.0" + resolved "https://registry.yarnpkg.com/qrcode/-/qrcode-1.5.0.tgz#95abb8a91fdafd86f8190f2836abbfc500c72d1b" + integrity sha512-9MgRpgVc+/+47dFvQeD6U2s0Z92EsKzcHogtum4QB+UNd025WOJSHvn/hjk9xmzj7Stj95CyUAs31mrjxliEsQ== + dependencies: + dijkstrajs "^1.0.1" + encode-utf8 "^1.0.3" + pngjs "^5.0.0" + yargs "^15.3.1" + +queue-microtask@^1.2.2: + version "1.2.3" + resolved "https://registry.yarnpkg.com/queue-microtask/-/queue-microtask-1.2.3.tgz#4929228bbc724dfac43e0efb058caf7b6cfb6243" + integrity sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A== + +quick-lru@^5.1.1: + version "5.1.1" + resolved "https://registry.yarnpkg.com/quick-lru/-/quick-lru-5.1.1.tgz#366493e6b3e42a3a6885e2e99d18f80fb7a8c932" + integrity sha512-WuyALRjWPDGtt/wzJiadO5AXY+8hZ80hVpe6MyivgraREW751X3SbhRvG3eLKOYN+8VEvqLcf3wdnt44Z4S4SA== + +read-cache@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/read-cache/-/read-cache-1.0.0.tgz#e664ef31161166c9751cdbe8dbcf86b5fb58f774" + integrity sha512-Owdv/Ft7IjOgm/i0xvNDZ1LrRANRfew4b2prF3OWMQLxLfu3bS8FVhCsrSCMK4lR56Y9ya+AThoTpDCTxCmpRA== + dependencies: + pify "^2.3.0" + +readdirp@~3.6.0: + version "3.6.0" + resolved "https://registry.yarnpkg.com/readdirp/-/readdirp-3.6.0.tgz#74a370bd857116e245b29cc97340cd431a02a6c7" + integrity sha512-hOS089on8RduqdbhvQ5Z37A0ESjsqz6qnRcffsMU3495FuTdqSm+7bhJ29JvIOsBDEEnan5DPu9t3To9VRlMzA== + dependencies: + picomatch "^2.2.1" + +require-directory@^2.1.1: + version "2.1.1" + resolved "https://registry.yarnpkg.com/require-directory/-/require-directory-2.1.1.tgz#8c64ad5fd30dab1c976e2344ffe7f792a6a6df42" + integrity sha1-jGStX9MNqxyXbiNE/+f3kqam30I= + +require-main-filename@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/require-main-filename/-/require-main-filename-2.0.0.tgz#d0b329ecc7cc0f61649f62215be69af54aa8989b" + integrity sha512-NKN5kMDylKuldxYLSUfrbo5Tuzh4hd+2E8NPPX02mZtn1VuREQToYe/ZdlJy+J3uCpfaiGF05e7B8W0iXbQHmg== + +resolve@^1.1.7, resolve@^1.22.1: + version "1.22.1" + resolved "https://registry.yarnpkg.com/resolve/-/resolve-1.22.1.tgz#27cb2ebb53f91abb49470a928bba7558066ac177" + integrity sha512-nBpuuYuY5jFsli/JIs1oldw6fOQCBioohqWZg/2hiaOybXOft4lonv85uDOKXdf8rhyK159cxU5cDcK/NKk8zw== + dependencies: + is-core-module "^2.9.0" + path-parse "^1.0.7" + supports-preserve-symlinks-flag "^1.0.0" + +reusify@^1.0.4: + version "1.0.4" + resolved "https://registry.yarnpkg.com/reusify/-/reusify-1.0.4.tgz#90da382b1e126efc02146e90845a88db12925d76" + integrity sha512-U9nH88a3fc/ekCF1l0/UP1IosiuIjyTh7hBvXVMHYgVcfGvt897Xguj2UOLDeI5BG2m7/uwyaLVT6fbtCwTyzw== + +run-parallel@^1.1.9: + version "1.2.0" + resolved "https://registry.yarnpkg.com/run-parallel/-/run-parallel-1.2.0.tgz#66d1368da7bdf921eb9d95bd1a9229e7f21a43ee" + integrity sha512-5l4VyZR86LZ/lDxZTR6jqL8AFE2S0IFLMP26AbjsLVADxHdhB/c0GUsH+y39UfCi3dzz8OlQuPmnaJOMoDHQBA== + dependencies: + queue-microtask "^1.2.2" + +set-blocking@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/set-blocking/-/set-blocking-2.0.0.tgz#045f9782d011ae9a6803ddd382b24392b3d890f7" + integrity sha1-BF+XgtARrppoA93TgrJDkrPYkPc= + +source-map-js@^1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/source-map-js/-/source-map-js-1.0.2.tgz#adbc361d9c62df380125e7f161f71c826f1e490c" + integrity sha512-R0XvVJ9WusLiqTCEiGCmICCMplcCkIwwR11mOSD9CR5u+IXYdiseeEuXCVAjS54zqwkLcPNnmU4OeJ6tUrWhDw== + +string-width@^4.1.0, string-width@^4.2.0, string-width@^4.2.3: + version "4.2.3" + resolved "https://registry.yarnpkg.com/string-width/-/string-width-4.2.3.tgz#269c7117d27b05ad2e536830a8ec895ef9c6d010" + integrity sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g== + dependencies: + emoji-regex "^8.0.0" + is-fullwidth-code-point "^3.0.0" + strip-ansi "^6.0.1" + +strip-ansi@^6.0.0, strip-ansi@^6.0.1: + version "6.0.1" + resolved "https://registry.yarnpkg.com/strip-ansi/-/strip-ansi-6.0.1.tgz#9e26c63d30f53443e9489495b2105d37b67a85d9" + integrity sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A== + dependencies: + ansi-regex "^5.0.1" + +supports-preserve-symlinks-flag@^1.0.0: + version "1.0.0" + resolved "https://registry.yarnpkg.com/supports-preserve-symlinks-flag/-/supports-preserve-symlinks-flag-1.0.0.tgz#6eda4bd344a3c94aea376d4cc31bc77311039e09" + integrity sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w== + +tailwindcss@^3.1.6: + version "3.1.6" + resolved "https://registry.yarnpkg.com/tailwindcss/-/tailwindcss-3.1.6.tgz#bcb719357776c39e6376a8d84e9834b2b19a49f1" + integrity sha512-7skAOY56erZAFQssT1xkpk+kWt2NrO45kORlxFPXUt3CiGsVPhH1smuH5XoDH6sGPXLyBv+zgCKA2HWBsgCytg== + dependencies: + arg "^5.0.2" + chokidar "^3.5.3" + color-name "^1.1.4" + detective "^5.2.1" + didyoumean "^1.2.2" + dlv "^1.1.3" + fast-glob "^3.2.11" + glob-parent "^6.0.2" + is-glob "^4.0.3" + lilconfig "^2.0.5" + normalize-path "^3.0.0" + object-hash "^3.0.0" + picocolors "^1.0.0" + postcss "^8.4.14" + postcss-import "^14.1.0" + postcss-js "^4.0.0" + postcss-load-config "^3.1.4" + postcss-nested "5.0.6" + postcss-selector-parser "^6.0.10" + postcss-value-parser "^4.2.0" + quick-lru "^5.1.1" + resolve "^1.22.1" + +to-regex-range@^5.0.1: + version "5.0.1" + resolved "https://registry.yarnpkg.com/to-regex-range/-/to-regex-range-5.0.1.tgz#1648c44aae7c8d988a326018ed72f5b4dd0392e4" + integrity sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ== + dependencies: + is-number "^7.0.0" + +typescript@>=3.0.1, typescript@^4.7.4: + version "4.7.4" + resolved "https://registry.yarnpkg.com/typescript/-/typescript-4.7.4.tgz#1a88596d1cf47d59507a1bcdfb5b9dfe4d488235" + integrity sha512-C0WQT0gezHuw6AdY1M2jxUO83Rjf0HP7Sk1DtXj6j1EwkQNZrHAg2XPWlq62oqEhYvONq5pkC2Y9oPljWToLmQ== + +util-deprecate@^1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/util-deprecate/-/util-deprecate-1.0.2.tgz#450d4dc9fa70de732762fbd2d4a28981419a0ccf" + integrity sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw== + +which-module@^2.0.0: + version "2.0.0" + resolved "https://registry.yarnpkg.com/which-module/-/which-module-2.0.0.tgz#d9ef07dce77b9902b8a3a8fa4b31c3e3f7e6e87a" + integrity sha1-2e8H3Od7mQK4o6j6SzHD4/fm6Ho= + +wrap-ansi@^6.2.0: + version "6.2.0" + resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-6.2.0.tgz#e9393ba07102e6c91a3b221478f0257cd2856e53" + integrity sha512-r6lPcBGxZXlIcymEu7InxDMhdW0KDxpLgoFLcguasxCaJ/SOIZwINatK9KY/tf+ZrlywOKU0UDj3ATXUBfxJXA== + dependencies: + ansi-styles "^4.0.0" + string-width "^4.1.0" + strip-ansi "^6.0.0" + +wrap-ansi@^7.0.0: + version "7.0.0" + resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-7.0.0.tgz#67e145cff510a6a6984bdf1152911d69d2eb9e43" + integrity sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q== + dependencies: + ansi-styles "^4.0.0" + string-width "^4.1.0" + strip-ansi "^6.0.0" + +xtend@^4.0.2: + version "4.0.2" + resolved "https://registry.yarnpkg.com/xtend/-/xtend-4.0.2.tgz#bb72779f5fa465186b1f438f674fa347fdb5db54" + integrity sha512-LKYU1iAXJXUgAXn9URjiu+MWhyUXHsvfp7mcuYm9dSUKK0/CjtrUwFAxD82/mCWbtLsGjFIad0wIsod4zrTAEQ== + +xterm-addon-fit@^0.7.0: + version "0.7.0" + resolved "https://registry.yarnpkg.com/xterm-addon-fit/-/xterm-addon-fit-0.7.0.tgz#b8ade6d96e63b47443862088f6670b49fb752c6a" + integrity sha512-tQgHGoHqRTgeROPnvmtEJywLKoC/V9eNs4bLLz7iyJr1aW/QFzRwfd3MGiJ6odJd9xEfxcW36/xRU47JkD5NKQ== + +xterm-addon-web-links@^0.8.0: + version "0.8.0" + resolved "https://registry.yarnpkg.com/xterm-addon-web-links/-/xterm-addon-web-links-0.8.0.tgz#2cb1d57129271022569208578b0bf4774e7e6ea9" + integrity sha512-J4tKngmIu20ytX9SEJjAP3UGksah7iALqBtfTwT9ZnmFHVplCumYQsUJfKuS+JwMhjsjH61YXfndenLNvjRrEw== + +xterm@^5.1.0: + version "5.1.0" + resolved "https://registry.yarnpkg.com/xterm/-/xterm-5.1.0.tgz#3e160d60e6801c864b55adf19171c49d2ff2b4fc" + integrity sha512-LovENH4WDzpwynj+OTkLyZgJPeDom9Gra4DMlGAgz6pZhIDCQ+YuO7yfwanY+gVbn/mmZIStNOnVRU/ikQuAEQ== + +y18n@^4.0.0: + version "4.0.3" + resolved "https://registry.yarnpkg.com/y18n/-/y18n-4.0.3.tgz#b5f259c82cd6e336921efd7bfd8bf560de9eeedf" + integrity sha512-JKhqTOwSrqNA1NY5lSztJ1GrBiUodLMmIZuLiDaMRJ+itFd+ABVE8XBjOvIWL+rSqNDC74LCSFmlb/U4UZ4hJQ== + +y18n@^5.0.5: + version "5.0.8" + resolved "https://registry.yarnpkg.com/y18n/-/y18n-5.0.8.tgz#7f4934d0f7ca8c56f95314939ddcd2dd91ce1d55" + integrity sha512-0pfFzegeDWJHJIAmTLRP2DwHjdF5s7jo9tuztdQxAhINCdvS+3nGINqPd00AphqJR/0LhANUS6/+7SCb98YOfA== + +yaml@^1.10.2: + version "1.10.2" + resolved "https://registry.yarnpkg.com/yaml/-/yaml-1.10.2.tgz#2301c5ffbf12b467de8da2333a459e29e7920e4b" + integrity sha512-r3vXyErRCYJ7wg28yvBY5VSoAF8ZvlcW9/BwUzEtUsjvX/DKs24dIkuwjtuprwJJHsbyUbLApepYTR1BN4uHrg== + +yargs-parser@^18.1.2: + version "18.1.3" + resolved "https://registry.yarnpkg.com/yargs-parser/-/yargs-parser-18.1.3.tgz#be68c4975c6b2abf469236b0c870362fab09a7b0" + integrity sha512-o50j0JeToy/4K6OZcaQmW6lyXXKhq7csREXcDwk2omFPJEwUNOVtJKvmDr9EI1fAJZUyZcRF7kxGBWmRXudrCQ== + dependencies: + camelcase "^5.0.0" + decamelize "^1.2.0" + +yargs-parser@^21.0.0: + version "21.1.1" + resolved "https://registry.yarnpkg.com/yargs-parser/-/yargs-parser-21.1.1.tgz#9096bceebf990d21bb31fa9516e0ede294a77d35" + integrity sha512-tVpsJW7DdjecAiFpbIB1e3qxIQsE6NoPc5/eTdrbbIC4h0LVsWhnoa3g+m2HclBIujHzsxZ4VJVA+GUuc2/LBw== + +yargs@^15.3.1: + version "15.4.1" + resolved "https://registry.yarnpkg.com/yargs/-/yargs-15.4.1.tgz#0d87a16de01aee9d8bec2bfbf74f67851730f4f8" + integrity sha512-aePbxDmcYW++PaqBsJ+HYUFwCdv4LVvdnhBy78E57PIor8/OVvhMrADFFEDh8DHDFRv/O9i3lPhsENjO7QX0+A== + dependencies: + cliui "^6.0.0" + decamelize "^1.2.0" + find-up "^4.1.0" + get-caller-file "^2.0.1" + require-directory "^2.1.1" + require-main-filename "^2.0.0" + set-blocking "^2.0.0" + string-width "^4.2.0" + which-module "^2.0.0" + y18n "^4.0.0" + yargs-parser "^18.1.2" + +yargs@^17.2.1: + version "17.5.1" + resolved "https://registry.yarnpkg.com/yargs/-/yargs-17.5.1.tgz#e109900cab6fcb7fd44b1d8249166feb0b36e58e" + integrity sha512-t6YAJcxDkNX7NFYiVtKvWUz8l+PaKTLiL63mJYWR2GnHq2gjEWISzsLp9wg3aY36dY1j+gfIEL3pIF+XlJJfbA== + dependencies: + cliui "^7.0.2" + escalade "^3.1.1" + get-caller-file "^2.0.5" + require-directory "^2.1.1" + string-width "^4.2.3" + y18n "^5.0.5" + yargs-parser "^21.0.0" diff --git a/cmd/tsshd/tsshd.go b/cmd/tsshd/tsshd.go index 1ec09a0d4..950eb661c 100644 --- a/cmd/tsshd/tsshd.go +++ b/cmd/tsshd/tsshd.go @@ -1,12 +1,12 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ignore - -// The tsshd binary was an experimental SSH server that accepts connections -// from anybody on the same Tailscale network. -// -// Its functionality moved into tailscaled. -// -// See https://github.com/tailscale/tailscale/issues/3802 -package main +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ignore + +// The tsshd binary was an experimental SSH server that accepts connections +// from anybody on the same Tailscale network. +// +// Its functionality moved into tailscaled. +// +// See https://github.com/tailscale/tailscale/issues/3802 +package main diff --git a/control/controlbase/conn.go b/control/controlbase/conn.go index b6fc53b3a..dc22212e8 100644 --- a/control/controlbase/conn.go +++ b/control/controlbase/conn.go @@ -1,408 +1,408 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package controlbase implements the base transport of the Tailscale -// 2021 control protocol. -// -// The base transport implements Noise IK, instantiated with -// Curve25519, ChaCha20Poly1305 and BLAKE2s. -package controlbase - -import ( - "crypto/cipher" - "encoding/binary" - "fmt" - "net" - "sync" - "time" - - "golang.org/x/crypto/blake2s" - chp "golang.org/x/crypto/chacha20poly1305" - "tailscale.com/types/key" -) - -const ( - // maxMessageSize is the maximum size of a protocol frame on the - // wire, including header and payload. - maxMessageSize = 4096 - // maxCiphertextSize is the maximum amount of ciphertext bytes - // that one protocol frame can carry, after framing. - maxCiphertextSize = maxMessageSize - 3 - // maxPlaintextSize is the maximum amount of plaintext bytes that - // one protocol frame can carry, after encryption and framing. - maxPlaintextSize = maxCiphertextSize - chp.Overhead -) - -// A Conn is a secured Noise connection. It implements the net.Conn -// interface, with the unusual trait that any write error (including a -// SetWriteDeadline induced i/o timeout) causes all future writes to -// fail. -type Conn struct { - conn net.Conn - version uint16 - peer key.MachinePublic - handshakeHash [blake2s.Size]byte - rx rxState - tx txState -} - -// rxState is all the Conn state that Read uses. -type rxState struct { - sync.Mutex - cipher cipher.AEAD - nonce nonce - buf *maxMsgBuffer // or nil when reads exhausted - n int // number of valid bytes in buf - next int // offset of next undecrypted packet - plaintext []byte // slice into buf of decrypted bytes - hdrBuf [headerLen]byte // small buffer used when buf is nil -} - -// txState is all the Conn state that Write uses. -type txState struct { - sync.Mutex - cipher cipher.AEAD - nonce nonce - err error // records the first partial write error for all future calls -} - -// ProtocolVersion returns the protocol version that was used to -// establish this Conn. -func (c *Conn) ProtocolVersion() int { - return int(c.version) -} - -// HandshakeHash returns the Noise handshake hash for the connection, -// which can be used to bind other messages to this connection -// (i.e. to ensure that the message wasn't replayed from a different -// connection). -func (c *Conn) HandshakeHash() [blake2s.Size]byte { - return c.handshakeHash -} - -// Peer returns the peer's long-term public key. -func (c *Conn) Peer() key.MachinePublic { - return c.peer -} - -// readNLocked reads into c.rx.buf until buf contains at least total -// bytes. Returns a slice of the total bytes in rxBuf, or an -// error if fewer than total bytes are available. -// -// It may be called with a nil c.rx.buf only if total == headerLen. -// -// On success, c.rx.buf will be non-nil. -func (c *Conn) readNLocked(total int) ([]byte, error) { - if total > maxMessageSize { - return nil, errReadTooBig{total} - } - for { - if total <= c.rx.n { - return c.rx.buf[:total], nil - } - var n int - var err error - if c.rx.buf == nil { - if c.rx.n != 0 || total != headerLen { - panic("unexpected") - } - // Optimization to reduce memory usage. - // Most connections are blocked forever waiting for - // a read, so we don't want c.rx.buf to be allocated until - // we know there's data to read. Instead, when we're - // waiting for data to arrive here, read into the - // 3 byte hdrBuf: - n, err = c.conn.Read(c.rx.hdrBuf[:]) - if n > 0 { - c.rx.buf = getMaxMsgBuffer() - copy(c.rx.buf[:], c.rx.hdrBuf[:n]) - } - } else { - n, err = c.conn.Read(c.rx.buf[c.rx.n:]) - } - c.rx.n += n - if err != nil { - return nil, err - } - } -} - -// decryptLocked decrypts msg (which is header+ciphertext) in-place -// and sets c.rx.plaintext to the decrypted bytes. -func (c *Conn) decryptLocked(msg []byte) (err error) { - if msgType := msg[0]; msgType != msgTypeRecord { - return fmt.Errorf("received message with unexpected type %d, want %d", msgType, msgTypeRecord) - } - // We don't check the length field here, because the caller - // already did in order to figure out how big the msg slice should - // be. - ciphertext := msg[headerLen:] - - if !c.rx.nonce.Valid() { - return errCipherExhausted{} - } - - c.rx.plaintext, err = c.rx.cipher.Open(ciphertext[:0], c.rx.nonce[:], ciphertext, nil) - c.rx.nonce.Increment() - - if err != nil { - // Once a decryption has failed, our Conn is no longer - // synchronized with our peer. Nuke the cipher state to be - // safe, so that no further decryptions are attempted. Future - // read attempts will return net.ErrClosed. - c.rx.cipher = nil - } - return err -} - -// encryptLocked encrypts plaintext into buf (including the -// packet header) and returns a slice of the ciphertext, or an error -// if the cipher is exhausted (i.e. can no longer be used safely). -func (c *Conn) encryptLocked(plaintext []byte, buf *maxMsgBuffer) ([]byte, error) { - if !c.tx.nonce.Valid() { - // Received 2^64-1 messages on this cipher state. Connection - // is no longer usable. - return nil, errCipherExhausted{} - } - - buf[0] = msgTypeRecord - binary.BigEndian.PutUint16(buf[1:headerLen], uint16(len(plaintext)+chp.Overhead)) - ret := c.tx.cipher.Seal(buf[:headerLen], c.tx.nonce[:], plaintext, nil) - c.tx.nonce.Increment() - - return ret, nil -} - -// wholeMessageLocked returns a slice of one whole Noise transport -// message from c.rx.buf, if one whole message is available, and -// advances the read state to the next Noise message in the -// buffer. Returns nil without advancing read state if there isn't one -// whole message in c.rx.buf. -func (c *Conn) wholeMessageLocked() []byte { - available := c.rx.n - c.rx.next - if available < headerLen { - return nil - } - bs := c.rx.buf[c.rx.next:c.rx.n] - totalSize := headerLen + int(binary.BigEndian.Uint16(bs[1:3])) - if len(bs) < totalSize { - return nil - } - c.rx.next += totalSize - return bs[:totalSize] -} - -// decryptOneLocked decrypts one Noise transport message, reading from -// c.conn as needed, and sets c.rx.plaintext to point to the decrypted -// bytes. c.rx.plaintext is only valid if err == nil. -func (c *Conn) decryptOneLocked() error { - c.rx.plaintext = nil - - // Fast path: do we have one whole ciphertext frame buffered - // already? - if bs := c.wholeMessageLocked(); bs != nil { - return c.decryptLocked(bs) - } - - if c.rx.next != 0 { - // To simplify the read logic, move the remainder of the - // buffered bytes back to the head of the buffer, so we can - // grow it without worrying about wraparound. - c.rx.n = copy(c.rx.buf[:], c.rx.buf[c.rx.next:c.rx.n]) - c.rx.next = 0 - } - - // Return our buffer to the pool if it's empty, lest we be - // blocked in a long Read call, reading the 3 byte header. We - // don't to keep that buffer unnecessarily alive. - if c.rx.n == 0 && c.rx.next == 0 && c.rx.buf != nil { - bufPool.Put(c.rx.buf) - c.rx.buf = nil - } - - bs, err := c.readNLocked(headerLen) - if err != nil { - return err - } - // The rest of the header (besides the length field) gets verified - // in decryptLocked, not here. - messageLen := headerLen + int(binary.BigEndian.Uint16(bs[1:3])) - bs, err = c.readNLocked(messageLen) - if err != nil { - return err - } - - c.rx.next = len(bs) - - return c.decryptLocked(bs) -} - -// Read implements io.Reader. -func (c *Conn) Read(bs []byte) (int, error) { - c.rx.Lock() - defer c.rx.Unlock() - - if c.rx.cipher == nil { - return 0, net.ErrClosed - } - // If no plaintext is buffered, decrypt incoming frames until we - // have some plaintext. Zero-byte Noise frames are allowed in this - // protocol, which is why we have to loop here rather than decrypt - // a single additional frame. - for len(c.rx.plaintext) == 0 { - if err := c.decryptOneLocked(); err != nil { - return 0, err - } - } - n := copy(bs, c.rx.plaintext) - c.rx.plaintext = c.rx.plaintext[n:] - - // Lose slice's underlying array pointer to unneeded memory so - // GC can collect more. - if len(c.rx.plaintext) == 0 { - c.rx.plaintext = nil - } - return n, nil -} - -// Write implements io.Writer. -func (c *Conn) Write(bs []byte) (n int, err error) { - c.tx.Lock() - defer c.tx.Unlock() - - if c.tx.err != nil { - return 0, c.tx.err - } - defer func() { - if err != nil { - // All write errors are fatal for this conn, so clear the - // cipher state whenever an error happens. - c.tx.cipher = nil - } - if c.tx.err == nil { - // Only set c.tx.err if not nil so that we can return one - // error on the first failure, and a different one for - // subsequent calls. See the error handling around Write - // below for why. - c.tx.err = err - } - }() - - if c.tx.cipher == nil { - return 0, net.ErrClosed - } - - buf := getMaxMsgBuffer() - defer bufPool.Put(buf) - - var sent int - for len(bs) > 0 { - toSend := bs - if len(toSend) > maxPlaintextSize { - toSend = bs[:maxPlaintextSize] - } - bs = bs[len(toSend):] - - ciphertext, err := c.encryptLocked(toSend, buf) - if err != nil { - return sent, err - } - if _, err := c.conn.Write(ciphertext); err != nil { - // Return the raw error on the Write that actually - // failed. For future writes, return that error wrapped in - // a desync error. - c.tx.err = errPartialWrite{err} - return sent, err - } - sent += len(toSend) - } - return sent, nil -} - -// Close implements io.Closer. -func (c *Conn) Close() error { - closeErr := c.conn.Close() // unblocks any waiting reads or writes - - // Remove references to live cipher state. Strictly speaking this - // is unnecessary, but we want to try and hand the active cipher - // state to the garbage collector promptly, to preserve perfect - // forward secrecy as much as we can. - c.rx.Lock() - c.rx.cipher = nil - c.rx.Unlock() - c.tx.Lock() - c.tx.cipher = nil - c.tx.Unlock() - return closeErr -} - -func (c *Conn) LocalAddr() net.Addr { return c.conn.LocalAddr() } -func (c *Conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() } -func (c *Conn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) } -func (c *Conn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) } -func (c *Conn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) } - -// errCipherExhausted is the error returned when we run out of nonces -// on a cipher. -type errCipherExhausted struct{} - -func (errCipherExhausted) Error() string { - return "cipher exhausted, no more nonces available for current key" -} -func (errCipherExhausted) Timeout() bool { return false } -func (errCipherExhausted) Temporary() bool { return false } - -// errPartialWrite is the error returned when the cipher state has -// become unusable due to a past partial write. -type errPartialWrite struct { - err error -} - -func (e errPartialWrite) Error() string { - return fmt.Sprintf("cipher state desynchronized due to partial write (%v)", e.err) -} -func (e errPartialWrite) Unwrap() error { return e.err } -func (e errPartialWrite) Temporary() bool { return false } -func (e errPartialWrite) Timeout() bool { return false } - -// errReadTooBig is the error returned when the peer sent an -// unacceptably large Noise frame. -type errReadTooBig struct { - requested int -} - -func (e errReadTooBig) Error() string { - return fmt.Sprintf("requested read of %d bytes exceeds max allowed Noise frame size", e.requested) -} -func (e errReadTooBig) Temporary() bool { - // permanent error because this error only occurs when our peer - // sends us a frame so large we're unwilling to ever decode it. - return false -} -func (e errReadTooBig) Timeout() bool { return false } - -type nonce [chp.NonceSize]byte - -func (n *nonce) Valid() bool { - return binary.BigEndian.Uint32(n[:4]) == 0 && binary.BigEndian.Uint64(n[4:]) != invalidNonce -} - -func (n *nonce) Increment() { - if !n.Valid() { - panic("increment of invalid nonce") - } - binary.BigEndian.PutUint64(n[4:], 1+binary.BigEndian.Uint64(n[4:])) -} - -type maxMsgBuffer [maxMessageSize]byte - -// bufPool holds the temporary buffers for Conn.Read & Write. -var bufPool = &sync.Pool{ - New: func() any { - return new(maxMsgBuffer) - }, -} - -func getMaxMsgBuffer() *maxMsgBuffer { - return bufPool.Get().(*maxMsgBuffer) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package controlbase implements the base transport of the Tailscale +// 2021 control protocol. +// +// The base transport implements Noise IK, instantiated with +// Curve25519, ChaCha20Poly1305 and BLAKE2s. +package controlbase + +import ( + "crypto/cipher" + "encoding/binary" + "fmt" + "net" + "sync" + "time" + + "golang.org/x/crypto/blake2s" + chp "golang.org/x/crypto/chacha20poly1305" + "tailscale.com/types/key" +) + +const ( + // maxMessageSize is the maximum size of a protocol frame on the + // wire, including header and payload. + maxMessageSize = 4096 + // maxCiphertextSize is the maximum amount of ciphertext bytes + // that one protocol frame can carry, after framing. + maxCiphertextSize = maxMessageSize - 3 + // maxPlaintextSize is the maximum amount of plaintext bytes that + // one protocol frame can carry, after encryption and framing. + maxPlaintextSize = maxCiphertextSize - chp.Overhead +) + +// A Conn is a secured Noise connection. It implements the net.Conn +// interface, with the unusual trait that any write error (including a +// SetWriteDeadline induced i/o timeout) causes all future writes to +// fail. +type Conn struct { + conn net.Conn + version uint16 + peer key.MachinePublic + handshakeHash [blake2s.Size]byte + rx rxState + tx txState +} + +// rxState is all the Conn state that Read uses. +type rxState struct { + sync.Mutex + cipher cipher.AEAD + nonce nonce + buf *maxMsgBuffer // or nil when reads exhausted + n int // number of valid bytes in buf + next int // offset of next undecrypted packet + plaintext []byte // slice into buf of decrypted bytes + hdrBuf [headerLen]byte // small buffer used when buf is nil +} + +// txState is all the Conn state that Write uses. +type txState struct { + sync.Mutex + cipher cipher.AEAD + nonce nonce + err error // records the first partial write error for all future calls +} + +// ProtocolVersion returns the protocol version that was used to +// establish this Conn. +func (c *Conn) ProtocolVersion() int { + return int(c.version) +} + +// HandshakeHash returns the Noise handshake hash for the connection, +// which can be used to bind other messages to this connection +// (i.e. to ensure that the message wasn't replayed from a different +// connection). +func (c *Conn) HandshakeHash() [blake2s.Size]byte { + return c.handshakeHash +} + +// Peer returns the peer's long-term public key. +func (c *Conn) Peer() key.MachinePublic { + return c.peer +} + +// readNLocked reads into c.rx.buf until buf contains at least total +// bytes. Returns a slice of the total bytes in rxBuf, or an +// error if fewer than total bytes are available. +// +// It may be called with a nil c.rx.buf only if total == headerLen. +// +// On success, c.rx.buf will be non-nil. +func (c *Conn) readNLocked(total int) ([]byte, error) { + if total > maxMessageSize { + return nil, errReadTooBig{total} + } + for { + if total <= c.rx.n { + return c.rx.buf[:total], nil + } + var n int + var err error + if c.rx.buf == nil { + if c.rx.n != 0 || total != headerLen { + panic("unexpected") + } + // Optimization to reduce memory usage. + // Most connections are blocked forever waiting for + // a read, so we don't want c.rx.buf to be allocated until + // we know there's data to read. Instead, when we're + // waiting for data to arrive here, read into the + // 3 byte hdrBuf: + n, err = c.conn.Read(c.rx.hdrBuf[:]) + if n > 0 { + c.rx.buf = getMaxMsgBuffer() + copy(c.rx.buf[:], c.rx.hdrBuf[:n]) + } + } else { + n, err = c.conn.Read(c.rx.buf[c.rx.n:]) + } + c.rx.n += n + if err != nil { + return nil, err + } + } +} + +// decryptLocked decrypts msg (which is header+ciphertext) in-place +// and sets c.rx.plaintext to the decrypted bytes. +func (c *Conn) decryptLocked(msg []byte) (err error) { + if msgType := msg[0]; msgType != msgTypeRecord { + return fmt.Errorf("received message with unexpected type %d, want %d", msgType, msgTypeRecord) + } + // We don't check the length field here, because the caller + // already did in order to figure out how big the msg slice should + // be. + ciphertext := msg[headerLen:] + + if !c.rx.nonce.Valid() { + return errCipherExhausted{} + } + + c.rx.plaintext, err = c.rx.cipher.Open(ciphertext[:0], c.rx.nonce[:], ciphertext, nil) + c.rx.nonce.Increment() + + if err != nil { + // Once a decryption has failed, our Conn is no longer + // synchronized with our peer. Nuke the cipher state to be + // safe, so that no further decryptions are attempted. Future + // read attempts will return net.ErrClosed. + c.rx.cipher = nil + } + return err +} + +// encryptLocked encrypts plaintext into buf (including the +// packet header) and returns a slice of the ciphertext, or an error +// if the cipher is exhausted (i.e. can no longer be used safely). +func (c *Conn) encryptLocked(plaintext []byte, buf *maxMsgBuffer) ([]byte, error) { + if !c.tx.nonce.Valid() { + // Received 2^64-1 messages on this cipher state. Connection + // is no longer usable. + return nil, errCipherExhausted{} + } + + buf[0] = msgTypeRecord + binary.BigEndian.PutUint16(buf[1:headerLen], uint16(len(plaintext)+chp.Overhead)) + ret := c.tx.cipher.Seal(buf[:headerLen], c.tx.nonce[:], plaintext, nil) + c.tx.nonce.Increment() + + return ret, nil +} + +// wholeMessageLocked returns a slice of one whole Noise transport +// message from c.rx.buf, if one whole message is available, and +// advances the read state to the next Noise message in the +// buffer. Returns nil without advancing read state if there isn't one +// whole message in c.rx.buf. +func (c *Conn) wholeMessageLocked() []byte { + available := c.rx.n - c.rx.next + if available < headerLen { + return nil + } + bs := c.rx.buf[c.rx.next:c.rx.n] + totalSize := headerLen + int(binary.BigEndian.Uint16(bs[1:3])) + if len(bs) < totalSize { + return nil + } + c.rx.next += totalSize + return bs[:totalSize] +} + +// decryptOneLocked decrypts one Noise transport message, reading from +// c.conn as needed, and sets c.rx.plaintext to point to the decrypted +// bytes. c.rx.plaintext is only valid if err == nil. +func (c *Conn) decryptOneLocked() error { + c.rx.plaintext = nil + + // Fast path: do we have one whole ciphertext frame buffered + // already? + if bs := c.wholeMessageLocked(); bs != nil { + return c.decryptLocked(bs) + } + + if c.rx.next != 0 { + // To simplify the read logic, move the remainder of the + // buffered bytes back to the head of the buffer, so we can + // grow it without worrying about wraparound. + c.rx.n = copy(c.rx.buf[:], c.rx.buf[c.rx.next:c.rx.n]) + c.rx.next = 0 + } + + // Return our buffer to the pool if it's empty, lest we be + // blocked in a long Read call, reading the 3 byte header. We + // don't to keep that buffer unnecessarily alive. + if c.rx.n == 0 && c.rx.next == 0 && c.rx.buf != nil { + bufPool.Put(c.rx.buf) + c.rx.buf = nil + } + + bs, err := c.readNLocked(headerLen) + if err != nil { + return err + } + // The rest of the header (besides the length field) gets verified + // in decryptLocked, not here. + messageLen := headerLen + int(binary.BigEndian.Uint16(bs[1:3])) + bs, err = c.readNLocked(messageLen) + if err != nil { + return err + } + + c.rx.next = len(bs) + + return c.decryptLocked(bs) +} + +// Read implements io.Reader. +func (c *Conn) Read(bs []byte) (int, error) { + c.rx.Lock() + defer c.rx.Unlock() + + if c.rx.cipher == nil { + return 0, net.ErrClosed + } + // If no plaintext is buffered, decrypt incoming frames until we + // have some plaintext. Zero-byte Noise frames are allowed in this + // protocol, which is why we have to loop here rather than decrypt + // a single additional frame. + for len(c.rx.plaintext) == 0 { + if err := c.decryptOneLocked(); err != nil { + return 0, err + } + } + n := copy(bs, c.rx.plaintext) + c.rx.plaintext = c.rx.plaintext[n:] + + // Lose slice's underlying array pointer to unneeded memory so + // GC can collect more. + if len(c.rx.plaintext) == 0 { + c.rx.plaintext = nil + } + return n, nil +} + +// Write implements io.Writer. +func (c *Conn) Write(bs []byte) (n int, err error) { + c.tx.Lock() + defer c.tx.Unlock() + + if c.tx.err != nil { + return 0, c.tx.err + } + defer func() { + if err != nil { + // All write errors are fatal for this conn, so clear the + // cipher state whenever an error happens. + c.tx.cipher = nil + } + if c.tx.err == nil { + // Only set c.tx.err if not nil so that we can return one + // error on the first failure, and a different one for + // subsequent calls. See the error handling around Write + // below for why. + c.tx.err = err + } + }() + + if c.tx.cipher == nil { + return 0, net.ErrClosed + } + + buf := getMaxMsgBuffer() + defer bufPool.Put(buf) + + var sent int + for len(bs) > 0 { + toSend := bs + if len(toSend) > maxPlaintextSize { + toSend = bs[:maxPlaintextSize] + } + bs = bs[len(toSend):] + + ciphertext, err := c.encryptLocked(toSend, buf) + if err != nil { + return sent, err + } + if _, err := c.conn.Write(ciphertext); err != nil { + // Return the raw error on the Write that actually + // failed. For future writes, return that error wrapped in + // a desync error. + c.tx.err = errPartialWrite{err} + return sent, err + } + sent += len(toSend) + } + return sent, nil +} + +// Close implements io.Closer. +func (c *Conn) Close() error { + closeErr := c.conn.Close() // unblocks any waiting reads or writes + + // Remove references to live cipher state. Strictly speaking this + // is unnecessary, but we want to try and hand the active cipher + // state to the garbage collector promptly, to preserve perfect + // forward secrecy as much as we can. + c.rx.Lock() + c.rx.cipher = nil + c.rx.Unlock() + c.tx.Lock() + c.tx.cipher = nil + c.tx.Unlock() + return closeErr +} + +func (c *Conn) LocalAddr() net.Addr { return c.conn.LocalAddr() } +func (c *Conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() } +func (c *Conn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) } +func (c *Conn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) } +func (c *Conn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) } + +// errCipherExhausted is the error returned when we run out of nonces +// on a cipher. +type errCipherExhausted struct{} + +func (errCipherExhausted) Error() string { + return "cipher exhausted, no more nonces available for current key" +} +func (errCipherExhausted) Timeout() bool { return false } +func (errCipherExhausted) Temporary() bool { return false } + +// errPartialWrite is the error returned when the cipher state has +// become unusable due to a past partial write. +type errPartialWrite struct { + err error +} + +func (e errPartialWrite) Error() string { + return fmt.Sprintf("cipher state desynchronized due to partial write (%v)", e.err) +} +func (e errPartialWrite) Unwrap() error { return e.err } +func (e errPartialWrite) Temporary() bool { return false } +func (e errPartialWrite) Timeout() bool { return false } + +// errReadTooBig is the error returned when the peer sent an +// unacceptably large Noise frame. +type errReadTooBig struct { + requested int +} + +func (e errReadTooBig) Error() string { + return fmt.Sprintf("requested read of %d bytes exceeds max allowed Noise frame size", e.requested) +} +func (e errReadTooBig) Temporary() bool { + // permanent error because this error only occurs when our peer + // sends us a frame so large we're unwilling to ever decode it. + return false +} +func (e errReadTooBig) Timeout() bool { return false } + +type nonce [chp.NonceSize]byte + +func (n *nonce) Valid() bool { + return binary.BigEndian.Uint32(n[:4]) == 0 && binary.BigEndian.Uint64(n[4:]) != invalidNonce +} + +func (n *nonce) Increment() { + if !n.Valid() { + panic("increment of invalid nonce") + } + binary.BigEndian.PutUint64(n[4:], 1+binary.BigEndian.Uint64(n[4:])) +} + +type maxMsgBuffer [maxMessageSize]byte + +// bufPool holds the temporary buffers for Conn.Read & Write. +var bufPool = &sync.Pool{ + New: func() any { + return new(maxMsgBuffer) + }, +} + +func getMaxMsgBuffer() *maxMsgBuffer { + return bufPool.Get().(*maxMsgBuffer) +} diff --git a/control/controlbase/handshake.go b/control/controlbase/handshake.go index 937969a30..765a4620b 100644 --- a/control/controlbase/handshake.go +++ b/control/controlbase/handshake.go @@ -1,494 +1,494 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package controlbase - -import ( - "context" - "crypto/cipher" - "encoding/binary" - "errors" - "fmt" - "hash" - "io" - "net" - "strconv" - "time" - - "go4.org/mem" - "golang.org/x/crypto/blake2s" - chp "golang.org/x/crypto/chacha20poly1305" - "golang.org/x/crypto/curve25519" - "golang.org/x/crypto/hkdf" - "tailscale.com/types/key" -) - -const ( - // protocolName is the name of the specific instantiation of Noise - // that the control protocol uses. This string's value is fixed by - // the Noise spec, and shouldn't be changed unless we're updating - // the control protocol to use a different Noise instance. - protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s" - // protocolVersion is the version of the control protocol that - // Client will use when initiating a handshake. - //protocolVersion uint16 = 1 - // protocolVersionPrefix is the name portion of the protocol - // name+version string that gets mixed into the handshake as a - // prologue. - // - // This mixing verifies that both clients agree that they're - // executing the control protocol at a specific version that - // matches the advertised version in the cleartext packet header. - protocolVersionPrefix = "Tailscale Control Protocol v" - invalidNonce = ^uint64(0) -) - -func protocolVersionPrologue(version uint16) []byte { - ret := make([]byte, 0, len(protocolVersionPrefix)+5) // 5 bytes is enough to encode all possible version numbers. - ret = append(ret, protocolVersionPrefix...) - return strconv.AppendUint(ret, uint64(version), 10) -} - -// HandshakeContinuation upgrades a net.Conn to a Conn. The net.Conn -// is assumed to have already sent the client>server handshake -// initiation message. -type HandshakeContinuation func(context.Context, net.Conn) (*Conn, error) - -// ClientDeferred initiates a control client handshake, returning the -// initial message to send to the server and a continuation to -// finalize the handshake. -// -// ClientDeferred is split in this way for RTT reduction: we run this -// protocol after negotiating a protocol switch from HTTP/HTTPS. If we -// completely serialized the negotiation followed by the handshake, -// we'd pay an extra RTT to transmit the handshake initiation after -// protocol switching. By splitting the handshake into an initial -// message and a continuation, we can embed the handshake initiation -// into the HTTP protocol switching request and avoid a bit of delay. -func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (initialHandshake []byte, continueHandshake HandshakeContinuation, err error) { - var s symmetricState - s.Initialize() - - // prologue - s.MixHash(protocolVersionPrologue(protocolVersion)) - - // <- s - // ... - s.MixHash(controlKey.UntypedBytes()) - - // -> e, es, s, ss - init := mkInitiationMessage(protocolVersion) - machineEphemeral := key.NewMachine() - machineEphemeralPub := machineEphemeral.Public() - copy(init.EphemeralPub(), machineEphemeralPub.UntypedBytes()) - s.MixHash(machineEphemeralPub.UntypedBytes()) - cipher, err := s.MixDH(machineEphemeral, controlKey) - if err != nil { - return nil, nil, fmt.Errorf("computing es: %w", err) - } - machineKeyPub := machineKey.Public() - s.EncryptAndHash(cipher, init.MachinePub(), machineKeyPub.UntypedBytes()) - cipher, err = s.MixDH(machineKey, controlKey) - if err != nil { - return nil, nil, fmt.Errorf("computing ss: %w", err) - } - s.EncryptAndHash(cipher, init.Tag(), nil) // empty message payload - - cont := func(ctx context.Context, conn net.Conn) (*Conn, error) { - return continueClientHandshake(ctx, conn, &s, machineKey, machineEphemeral, controlKey, protocolVersion) - } - return init[:], cont, nil -} - -// Client wraps ClientDeferred and immediately invokes the returned -// continuation with conn. -// -// This is a helper for when you don't need the fancy -// continuation-style handshake, and just want to synchronously -// upgrade a net.Conn to a secure transport. -func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) { - init, cont, err := ClientDeferred(machineKey, controlKey, protocolVersion) - if err != nil { - return nil, err - } - if _, err := conn.Write(init); err != nil { - return nil, err - } - return cont(ctx, conn) -} - -func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricState, machineKey, machineEphemeral key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) { - // No matter what, this function can only run once per s. Ensure - // attempted reuse causes a panic. - defer func() { - s.finished = true - }() - - if deadline, ok := ctx.Deadline(); ok { - if err := conn.SetDeadline(deadline); err != nil { - return nil, fmt.Errorf("setting conn deadline: %w", err) - } - defer func() { - conn.SetDeadline(time.Time{}) - }() - } - - // Read in the payload and look for errors/protocol violations from the server. - var resp responseMessage - if _, err := io.ReadFull(conn, resp.Header()); err != nil { - return nil, fmt.Errorf("reading response header: %w", err) - } - if resp.Type() != msgTypeResponse { - if resp.Type() != msgTypeError { - return nil, fmt.Errorf("unexpected response message type %d", resp.Type()) - } - msg := make([]byte, resp.Length()) - if _, err := io.ReadFull(conn, msg); err != nil { - return nil, err - } - return nil, fmt.Errorf("server error: %q", msg) - } - if resp.Length() != len(resp.Payload()) { - return nil, fmt.Errorf("wrong length %d received for handshake response", resp.Length()) - } - if _, err := io.ReadFull(conn, resp.Payload()); err != nil { - return nil, err - } - - // <- e, ee, se - controlEphemeralPub := key.MachinePublicFromRaw32(mem.B(resp.EphemeralPub())) - s.MixHash(controlEphemeralPub.UntypedBytes()) - if _, err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil { - return nil, fmt.Errorf("computing ee: %w", err) - } - cipher, err := s.MixDH(machineKey, controlEphemeralPub) - if err != nil { - return nil, fmt.Errorf("computing se: %w", err) - } - if err := s.DecryptAndHash(cipher, nil, resp.Tag()); err != nil { - return nil, fmt.Errorf("decrypting payload: %w", err) - } - - c1, c2, err := s.Split() - if err != nil { - return nil, fmt.Errorf("finalizing handshake: %w", err) - } - - c := &Conn{ - conn: conn, - version: protocolVersion, - peer: controlKey, - handshakeHash: s.h, - tx: txState{ - cipher: c1, - }, - rx: rxState{ - cipher: c2, - }, - } - return c, nil -} - -// Server initiates a control server handshake, returning the resulting -// control connection. -// -// optionalInit can be the client's initial handshake message as -// returned by ClientDeferred, or nil in which case the initial -// message is read from conn. -// -// The context deadline, if any, covers the entire handshaking -// process. -func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) { - if deadline, ok := ctx.Deadline(); ok { - if err := conn.SetDeadline(deadline); err != nil { - return nil, fmt.Errorf("setting conn deadline: %w", err) - } - defer func() { - conn.SetDeadline(time.Time{}) - }() - } - - // Deliberately does not support formatting, so that we don't echo - // attacker-controlled input back to them. - sendErr := func(msg string) error { - if len(msg) >= 1<<16 { - msg = msg[:1<<16] - } - var hdr [headerLen]byte - hdr[0] = msgTypeError - binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg))) - if _, err := conn.Write(hdr[:]); err != nil { - return fmt.Errorf("sending %q error to client: %w", msg, err) - } - if _, err := io.WriteString(conn, msg); err != nil { - return fmt.Errorf("sending %q error to client: %w", msg, err) - } - return fmt.Errorf("refused client handshake: %q", msg) - } - - var s symmetricState - s.Initialize() - - var init initiationMessage - if optionalInit != nil { - if len(optionalInit) != len(init) { - return nil, sendErr("wrong handshake initiation size") - } - copy(init[:], optionalInit) - } else if _, err := io.ReadFull(conn, init.Header()); err != nil { - return nil, err - } - // Just a rename to make it more obvious what the value is. In the - // current implementation we don't need to block any protocol - // versions at this layer, it's safe to let the handshake proceed - // and then let the caller make decisions based on the agreed-upon - // protocol version. - clientVersion := init.Version() - if init.Type() != msgTypeInitiation { - return nil, sendErr("unexpected handshake message type") - } - if init.Length() != len(init.Payload()) { - return nil, sendErr("wrong handshake initiation length") - } - // if optionalInit was provided, we have the payload already. - if optionalInit == nil { - if _, err := io.ReadFull(conn, init.Payload()); err != nil { - return nil, err - } - } - - // prologue. Can only do this once we at least think the client is - // handshaking using a supported version. - s.MixHash(protocolVersionPrologue(clientVersion)) - - // <- s - // ... - controlKeyPub := controlKey.Public() - s.MixHash(controlKeyPub.UntypedBytes()) - - // -> e, es, s, ss - machineEphemeralPub := key.MachinePublicFromRaw32(mem.B(init.EphemeralPub())) - s.MixHash(machineEphemeralPub.UntypedBytes()) - cipher, err := s.MixDH(controlKey, machineEphemeralPub) - if err != nil { - return nil, fmt.Errorf("computing es: %w", err) - } - var machineKeyBytes [32]byte - if err := s.DecryptAndHash(cipher, machineKeyBytes[:], init.MachinePub()); err != nil { - return nil, fmt.Errorf("decrypting machine key: %w", err) - } - machineKey := key.MachinePublicFromRaw32(mem.B(machineKeyBytes[:])) - cipher, err = s.MixDH(controlKey, machineKey) - if err != nil { - return nil, fmt.Errorf("computing ss: %w", err) - } - if err := s.DecryptAndHash(cipher, nil, init.Tag()); err != nil { - return nil, fmt.Errorf("decrypting initiation tag: %w", err) - } - - // <- e, ee, se - resp := mkResponseMessage() - controlEphemeral := key.NewMachine() - controlEphemeralPub := controlEphemeral.Public() - copy(resp.EphemeralPub(), controlEphemeralPub.UntypedBytes()) - s.MixHash(controlEphemeralPub.UntypedBytes()) - if _, err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil { - return nil, fmt.Errorf("computing ee: %w", err) - } - cipher, err = s.MixDH(controlEphemeral, machineKey) - if err != nil { - return nil, fmt.Errorf("computing se: %w", err) - } - s.EncryptAndHash(cipher, resp.Tag(), nil) // empty message payload - - c1, c2, err := s.Split() - if err != nil { - return nil, fmt.Errorf("finalizing handshake: %w", err) - } - - if _, err := conn.Write(resp[:]); err != nil { - return nil, err - } - - c := &Conn{ - conn: conn, - version: clientVersion, - peer: machineKey, - handshakeHash: s.h, - tx: txState{ - cipher: c2, - }, - rx: rxState{ - cipher: c1, - }, - } - return c, nil -} - -// symmetricState contains the state of an in-flight handshake. -type symmetricState struct { - finished bool - - h [blake2s.Size]byte // hash of currently-processed handshake state - ck [blake2s.Size]byte // chaining key used to construct session keys at the end of the handshake -} - -func (s *symmetricState) checkFinished() { - if s.finished { - panic("attempted to use symmetricState after Split was called") - } -} - -// Initialize sets s to the initial handshake state, prior to -// processing any handshake messages. -func (s *symmetricState) Initialize() { - s.checkFinished() - s.h = blake2s.Sum256([]byte(protocolName)) - s.ck = s.h -} - -// MixHash updates s.h to be BLAKE2s(s.h || data), where || is -// concatenation. -func (s *symmetricState) MixHash(data []byte) { - s.checkFinished() - h := newBLAKE2s() - h.Write(s.h[:]) - h.Write(data) - h.Sum(s.h[:0]) -} - -// MixDH updates s.ck with the result of X25519(priv, pub) and returns -// a singleUseCHP that can be used to encrypt or decrypt handshake -// data. -// -// MixDH corresponds to MixKey(X25519(...))) in the spec. Implementing -// it as a single function allows for strongly-typed arguments that -// reduce the risk of error in the caller (e.g. invoking X25519 with -// two private keys, or two public keys), and thus producing the wrong -// calculation. -func (s *symmetricState) MixDH(priv key.MachinePrivate, pub key.MachinePublic) (*singleUseCHP, error) { - s.checkFinished() - keyData, err := curve25519.X25519(priv.UntypedBytes(), pub.UntypedBytes()) - if err != nil { - return nil, fmt.Errorf("computing X25519: %w", err) - } - - r := hkdf.New(newBLAKE2s, keyData, s.ck[:], nil) - if _, err := io.ReadFull(r, s.ck[:]); err != nil { - return nil, fmt.Errorf("extracting ck: %w", err) - } - var k [chp.KeySize]byte - if _, err := io.ReadFull(r, k[:]); err != nil { - return nil, fmt.Errorf("extracting k: %w", err) - } - return newSingleUseCHP(k), nil -} - -// EncryptAndHash encrypts plaintext into ciphertext (which must be -// the correct size to hold the encrypted plaintext) using cipher, -// mixes the ciphertext into s.h, and returns the ciphertext. -func (s *symmetricState) EncryptAndHash(cipher *singleUseCHP, ciphertext, plaintext []byte) { - s.checkFinished() - if len(ciphertext) != len(plaintext)+chp.Overhead { - panic("ciphertext is wrong size for given plaintext") - } - ret := cipher.Seal(ciphertext[:0], plaintext, s.h[:]) - s.MixHash(ret) -} - -// DecryptAndHash decrypts the given ciphertext into plaintext (which -// must be the correct size to hold the decrypted ciphertext) using -// cipher. If decryption is successful, it mixes the ciphertext into -// s.h. -func (s *symmetricState) DecryptAndHash(cipher *singleUseCHP, plaintext, ciphertext []byte) error { - s.checkFinished() - if len(ciphertext) != len(plaintext)+chp.Overhead { - return errors.New("plaintext is wrong size for given ciphertext") - } - if _, err := cipher.Open(plaintext[:0], ciphertext, s.h[:]); err != nil { - return err - } - s.MixHash(ciphertext) - return nil -} - -// Split returns two ChaCha20Poly1305 ciphers with keys derived from -// the current handshake state. Methods on s cannot be used again -// after calling Split. -func (s *symmetricState) Split() (c1, c2 cipher.AEAD, err error) { - s.finished = true - - var k1, k2 [chp.KeySize]byte - r := hkdf.New(newBLAKE2s, nil, s.ck[:], nil) - if _, err := io.ReadFull(r, k1[:]); err != nil { - return nil, nil, fmt.Errorf("extracting k1: %w", err) - } - if _, err := io.ReadFull(r, k2[:]); err != nil { - return nil, nil, fmt.Errorf("extracting k2: %w", err) - } - c1, err = chp.New(k1[:]) - if err != nil { - return nil, nil, fmt.Errorf("constructing AEAD c1: %w", err) - } - c2, err = chp.New(k2[:]) - if err != nil { - return nil, nil, fmt.Errorf("constructing AEAD c2: %w", err) - } - return c1, c2, nil -} - -// newBLAKE2s returns a hash.Hash implementing BLAKE2s, or panics on -// error. -func newBLAKE2s() hash.Hash { - h, err := blake2s.New256(nil) - if err != nil { - // Should never happen, errors only happen when using BLAKE2s - // in MAC mode with a key. - panic(err) - } - return h -} - -// newCHP returns a cipher.AEAD implementing ChaCha20Poly1305, or -// panics on error. -func newCHP(key [chp.KeySize]byte) cipher.AEAD { - aead, err := chp.New(key[:]) - if err != nil { - // Can only happen if we passed a key of the wrong length. The - // function signature prevents that. - panic(err) - } - return aead -} - -// singleUseCHP is an instance of ChaCha20Poly1305 that can be used -// only once, either for encrypting or decrypting, but not both. The -// chosen operation is always executed with an all-zeros -// nonce. Subsequent calls to either Seal or Open panic. -type singleUseCHP struct { - c cipher.AEAD -} - -func newSingleUseCHP(key [chp.KeySize]byte) *singleUseCHP { - return &singleUseCHP{newCHP(key)} -} - -func (c *singleUseCHP) Seal(dst, plaintext, additionalData []byte) []byte { - if c.c == nil { - panic("Attempted reuse of singleUseAEAD") - } - cipher := c.c - c.c = nil - var nonce [chp.NonceSize]byte - return cipher.Seal(dst, nonce[:], plaintext, additionalData) -} - -func (c *singleUseCHP) Open(dst, ciphertext, additionalData []byte) ([]byte, error) { - if c.c == nil { - panic("Attempted reuse of singleUseAEAD") - } - cipher := c.c - c.c = nil - var nonce [chp.NonceSize]byte - return cipher.Open(dst, nonce[:], ciphertext, additionalData) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package controlbase + +import ( + "context" + "crypto/cipher" + "encoding/binary" + "errors" + "fmt" + "hash" + "io" + "net" + "strconv" + "time" + + "go4.org/mem" + "golang.org/x/crypto/blake2s" + chp "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/hkdf" + "tailscale.com/types/key" +) + +const ( + // protocolName is the name of the specific instantiation of Noise + // that the control protocol uses. This string's value is fixed by + // the Noise spec, and shouldn't be changed unless we're updating + // the control protocol to use a different Noise instance. + protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s" + // protocolVersion is the version of the control protocol that + // Client will use when initiating a handshake. + //protocolVersion uint16 = 1 + // protocolVersionPrefix is the name portion of the protocol + // name+version string that gets mixed into the handshake as a + // prologue. + // + // This mixing verifies that both clients agree that they're + // executing the control protocol at a specific version that + // matches the advertised version in the cleartext packet header. + protocolVersionPrefix = "Tailscale Control Protocol v" + invalidNonce = ^uint64(0) +) + +func protocolVersionPrologue(version uint16) []byte { + ret := make([]byte, 0, len(protocolVersionPrefix)+5) // 5 bytes is enough to encode all possible version numbers. + ret = append(ret, protocolVersionPrefix...) + return strconv.AppendUint(ret, uint64(version), 10) +} + +// HandshakeContinuation upgrades a net.Conn to a Conn. The net.Conn +// is assumed to have already sent the client>server handshake +// initiation message. +type HandshakeContinuation func(context.Context, net.Conn) (*Conn, error) + +// ClientDeferred initiates a control client handshake, returning the +// initial message to send to the server and a continuation to +// finalize the handshake. +// +// ClientDeferred is split in this way for RTT reduction: we run this +// protocol after negotiating a protocol switch from HTTP/HTTPS. If we +// completely serialized the negotiation followed by the handshake, +// we'd pay an extra RTT to transmit the handshake initiation after +// protocol switching. By splitting the handshake into an initial +// message and a continuation, we can embed the handshake initiation +// into the HTTP protocol switching request and avoid a bit of delay. +func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (initialHandshake []byte, continueHandshake HandshakeContinuation, err error) { + var s symmetricState + s.Initialize() + + // prologue + s.MixHash(protocolVersionPrologue(protocolVersion)) + + // <- s + // ... + s.MixHash(controlKey.UntypedBytes()) + + // -> e, es, s, ss + init := mkInitiationMessage(protocolVersion) + machineEphemeral := key.NewMachine() + machineEphemeralPub := machineEphemeral.Public() + copy(init.EphemeralPub(), machineEphemeralPub.UntypedBytes()) + s.MixHash(machineEphemeralPub.UntypedBytes()) + cipher, err := s.MixDH(machineEphemeral, controlKey) + if err != nil { + return nil, nil, fmt.Errorf("computing es: %w", err) + } + machineKeyPub := machineKey.Public() + s.EncryptAndHash(cipher, init.MachinePub(), machineKeyPub.UntypedBytes()) + cipher, err = s.MixDH(machineKey, controlKey) + if err != nil { + return nil, nil, fmt.Errorf("computing ss: %w", err) + } + s.EncryptAndHash(cipher, init.Tag(), nil) // empty message payload + + cont := func(ctx context.Context, conn net.Conn) (*Conn, error) { + return continueClientHandshake(ctx, conn, &s, machineKey, machineEphemeral, controlKey, protocolVersion) + } + return init[:], cont, nil +} + +// Client wraps ClientDeferred and immediately invokes the returned +// continuation with conn. +// +// This is a helper for when you don't need the fancy +// continuation-style handshake, and just want to synchronously +// upgrade a net.Conn to a secure transport. +func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) { + init, cont, err := ClientDeferred(machineKey, controlKey, protocolVersion) + if err != nil { + return nil, err + } + if _, err := conn.Write(init); err != nil { + return nil, err + } + return cont(ctx, conn) +} + +func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricState, machineKey, machineEphemeral key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) { + // No matter what, this function can only run once per s. Ensure + // attempted reuse causes a panic. + defer func() { + s.finished = true + }() + + if deadline, ok := ctx.Deadline(); ok { + if err := conn.SetDeadline(deadline); err != nil { + return nil, fmt.Errorf("setting conn deadline: %w", err) + } + defer func() { + conn.SetDeadline(time.Time{}) + }() + } + + // Read in the payload and look for errors/protocol violations from the server. + var resp responseMessage + if _, err := io.ReadFull(conn, resp.Header()); err != nil { + return nil, fmt.Errorf("reading response header: %w", err) + } + if resp.Type() != msgTypeResponse { + if resp.Type() != msgTypeError { + return nil, fmt.Errorf("unexpected response message type %d", resp.Type()) + } + msg := make([]byte, resp.Length()) + if _, err := io.ReadFull(conn, msg); err != nil { + return nil, err + } + return nil, fmt.Errorf("server error: %q", msg) + } + if resp.Length() != len(resp.Payload()) { + return nil, fmt.Errorf("wrong length %d received for handshake response", resp.Length()) + } + if _, err := io.ReadFull(conn, resp.Payload()); err != nil { + return nil, err + } + + // <- e, ee, se + controlEphemeralPub := key.MachinePublicFromRaw32(mem.B(resp.EphemeralPub())) + s.MixHash(controlEphemeralPub.UntypedBytes()) + if _, err := s.MixDH(machineEphemeral, controlEphemeralPub); err != nil { + return nil, fmt.Errorf("computing ee: %w", err) + } + cipher, err := s.MixDH(machineKey, controlEphemeralPub) + if err != nil { + return nil, fmt.Errorf("computing se: %w", err) + } + if err := s.DecryptAndHash(cipher, nil, resp.Tag()); err != nil { + return nil, fmt.Errorf("decrypting payload: %w", err) + } + + c1, c2, err := s.Split() + if err != nil { + return nil, fmt.Errorf("finalizing handshake: %w", err) + } + + c := &Conn{ + conn: conn, + version: protocolVersion, + peer: controlKey, + handshakeHash: s.h, + tx: txState{ + cipher: c1, + }, + rx: rxState{ + cipher: c2, + }, + } + return c, nil +} + +// Server initiates a control server handshake, returning the resulting +// control connection. +// +// optionalInit can be the client's initial handshake message as +// returned by ClientDeferred, or nil in which case the initial +// message is read from conn. +// +// The context deadline, if any, covers the entire handshaking +// process. +func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) { + if deadline, ok := ctx.Deadline(); ok { + if err := conn.SetDeadline(deadline); err != nil { + return nil, fmt.Errorf("setting conn deadline: %w", err) + } + defer func() { + conn.SetDeadline(time.Time{}) + }() + } + + // Deliberately does not support formatting, so that we don't echo + // attacker-controlled input back to them. + sendErr := func(msg string) error { + if len(msg) >= 1<<16 { + msg = msg[:1<<16] + } + var hdr [headerLen]byte + hdr[0] = msgTypeError + binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg))) + if _, err := conn.Write(hdr[:]); err != nil { + return fmt.Errorf("sending %q error to client: %w", msg, err) + } + if _, err := io.WriteString(conn, msg); err != nil { + return fmt.Errorf("sending %q error to client: %w", msg, err) + } + return fmt.Errorf("refused client handshake: %q", msg) + } + + var s symmetricState + s.Initialize() + + var init initiationMessage + if optionalInit != nil { + if len(optionalInit) != len(init) { + return nil, sendErr("wrong handshake initiation size") + } + copy(init[:], optionalInit) + } else if _, err := io.ReadFull(conn, init.Header()); err != nil { + return nil, err + } + // Just a rename to make it more obvious what the value is. In the + // current implementation we don't need to block any protocol + // versions at this layer, it's safe to let the handshake proceed + // and then let the caller make decisions based on the agreed-upon + // protocol version. + clientVersion := init.Version() + if init.Type() != msgTypeInitiation { + return nil, sendErr("unexpected handshake message type") + } + if init.Length() != len(init.Payload()) { + return nil, sendErr("wrong handshake initiation length") + } + // if optionalInit was provided, we have the payload already. + if optionalInit == nil { + if _, err := io.ReadFull(conn, init.Payload()); err != nil { + return nil, err + } + } + + // prologue. Can only do this once we at least think the client is + // handshaking using a supported version. + s.MixHash(protocolVersionPrologue(clientVersion)) + + // <- s + // ... + controlKeyPub := controlKey.Public() + s.MixHash(controlKeyPub.UntypedBytes()) + + // -> e, es, s, ss + machineEphemeralPub := key.MachinePublicFromRaw32(mem.B(init.EphemeralPub())) + s.MixHash(machineEphemeralPub.UntypedBytes()) + cipher, err := s.MixDH(controlKey, machineEphemeralPub) + if err != nil { + return nil, fmt.Errorf("computing es: %w", err) + } + var machineKeyBytes [32]byte + if err := s.DecryptAndHash(cipher, machineKeyBytes[:], init.MachinePub()); err != nil { + return nil, fmt.Errorf("decrypting machine key: %w", err) + } + machineKey := key.MachinePublicFromRaw32(mem.B(machineKeyBytes[:])) + cipher, err = s.MixDH(controlKey, machineKey) + if err != nil { + return nil, fmt.Errorf("computing ss: %w", err) + } + if err := s.DecryptAndHash(cipher, nil, init.Tag()); err != nil { + return nil, fmt.Errorf("decrypting initiation tag: %w", err) + } + + // <- e, ee, se + resp := mkResponseMessage() + controlEphemeral := key.NewMachine() + controlEphemeralPub := controlEphemeral.Public() + copy(resp.EphemeralPub(), controlEphemeralPub.UntypedBytes()) + s.MixHash(controlEphemeralPub.UntypedBytes()) + if _, err := s.MixDH(controlEphemeral, machineEphemeralPub); err != nil { + return nil, fmt.Errorf("computing ee: %w", err) + } + cipher, err = s.MixDH(controlEphemeral, machineKey) + if err != nil { + return nil, fmt.Errorf("computing se: %w", err) + } + s.EncryptAndHash(cipher, resp.Tag(), nil) // empty message payload + + c1, c2, err := s.Split() + if err != nil { + return nil, fmt.Errorf("finalizing handshake: %w", err) + } + + if _, err := conn.Write(resp[:]); err != nil { + return nil, err + } + + c := &Conn{ + conn: conn, + version: clientVersion, + peer: machineKey, + handshakeHash: s.h, + tx: txState{ + cipher: c2, + }, + rx: rxState{ + cipher: c1, + }, + } + return c, nil +} + +// symmetricState contains the state of an in-flight handshake. +type symmetricState struct { + finished bool + + h [blake2s.Size]byte // hash of currently-processed handshake state + ck [blake2s.Size]byte // chaining key used to construct session keys at the end of the handshake +} + +func (s *symmetricState) checkFinished() { + if s.finished { + panic("attempted to use symmetricState after Split was called") + } +} + +// Initialize sets s to the initial handshake state, prior to +// processing any handshake messages. +func (s *symmetricState) Initialize() { + s.checkFinished() + s.h = blake2s.Sum256([]byte(protocolName)) + s.ck = s.h +} + +// MixHash updates s.h to be BLAKE2s(s.h || data), where || is +// concatenation. +func (s *symmetricState) MixHash(data []byte) { + s.checkFinished() + h := newBLAKE2s() + h.Write(s.h[:]) + h.Write(data) + h.Sum(s.h[:0]) +} + +// MixDH updates s.ck with the result of X25519(priv, pub) and returns +// a singleUseCHP that can be used to encrypt or decrypt handshake +// data. +// +// MixDH corresponds to MixKey(X25519(...))) in the spec. Implementing +// it as a single function allows for strongly-typed arguments that +// reduce the risk of error in the caller (e.g. invoking X25519 with +// two private keys, or two public keys), and thus producing the wrong +// calculation. +func (s *symmetricState) MixDH(priv key.MachinePrivate, pub key.MachinePublic) (*singleUseCHP, error) { + s.checkFinished() + keyData, err := curve25519.X25519(priv.UntypedBytes(), pub.UntypedBytes()) + if err != nil { + return nil, fmt.Errorf("computing X25519: %w", err) + } + + r := hkdf.New(newBLAKE2s, keyData, s.ck[:], nil) + if _, err := io.ReadFull(r, s.ck[:]); err != nil { + return nil, fmt.Errorf("extracting ck: %w", err) + } + var k [chp.KeySize]byte + if _, err := io.ReadFull(r, k[:]); err != nil { + return nil, fmt.Errorf("extracting k: %w", err) + } + return newSingleUseCHP(k), nil +} + +// EncryptAndHash encrypts plaintext into ciphertext (which must be +// the correct size to hold the encrypted plaintext) using cipher, +// mixes the ciphertext into s.h, and returns the ciphertext. +func (s *symmetricState) EncryptAndHash(cipher *singleUseCHP, ciphertext, plaintext []byte) { + s.checkFinished() + if len(ciphertext) != len(plaintext)+chp.Overhead { + panic("ciphertext is wrong size for given plaintext") + } + ret := cipher.Seal(ciphertext[:0], plaintext, s.h[:]) + s.MixHash(ret) +} + +// DecryptAndHash decrypts the given ciphertext into plaintext (which +// must be the correct size to hold the decrypted ciphertext) using +// cipher. If decryption is successful, it mixes the ciphertext into +// s.h. +func (s *symmetricState) DecryptAndHash(cipher *singleUseCHP, plaintext, ciphertext []byte) error { + s.checkFinished() + if len(ciphertext) != len(plaintext)+chp.Overhead { + return errors.New("plaintext is wrong size for given ciphertext") + } + if _, err := cipher.Open(plaintext[:0], ciphertext, s.h[:]); err != nil { + return err + } + s.MixHash(ciphertext) + return nil +} + +// Split returns two ChaCha20Poly1305 ciphers with keys derived from +// the current handshake state. Methods on s cannot be used again +// after calling Split. +func (s *symmetricState) Split() (c1, c2 cipher.AEAD, err error) { + s.finished = true + + var k1, k2 [chp.KeySize]byte + r := hkdf.New(newBLAKE2s, nil, s.ck[:], nil) + if _, err := io.ReadFull(r, k1[:]); err != nil { + return nil, nil, fmt.Errorf("extracting k1: %w", err) + } + if _, err := io.ReadFull(r, k2[:]); err != nil { + return nil, nil, fmt.Errorf("extracting k2: %w", err) + } + c1, err = chp.New(k1[:]) + if err != nil { + return nil, nil, fmt.Errorf("constructing AEAD c1: %w", err) + } + c2, err = chp.New(k2[:]) + if err != nil { + return nil, nil, fmt.Errorf("constructing AEAD c2: %w", err) + } + return c1, c2, nil +} + +// newBLAKE2s returns a hash.Hash implementing BLAKE2s, or panics on +// error. +func newBLAKE2s() hash.Hash { + h, err := blake2s.New256(nil) + if err != nil { + // Should never happen, errors only happen when using BLAKE2s + // in MAC mode with a key. + panic(err) + } + return h +} + +// newCHP returns a cipher.AEAD implementing ChaCha20Poly1305, or +// panics on error. +func newCHP(key [chp.KeySize]byte) cipher.AEAD { + aead, err := chp.New(key[:]) + if err != nil { + // Can only happen if we passed a key of the wrong length. The + // function signature prevents that. + panic(err) + } + return aead +} + +// singleUseCHP is an instance of ChaCha20Poly1305 that can be used +// only once, either for encrypting or decrypting, but not both. The +// chosen operation is always executed with an all-zeros +// nonce. Subsequent calls to either Seal or Open panic. +type singleUseCHP struct { + c cipher.AEAD +} + +func newSingleUseCHP(key [chp.KeySize]byte) *singleUseCHP { + return &singleUseCHP{newCHP(key)} +} + +func (c *singleUseCHP) Seal(dst, plaintext, additionalData []byte) []byte { + if c.c == nil { + panic("Attempted reuse of singleUseAEAD") + } + cipher := c.c + c.c = nil + var nonce [chp.NonceSize]byte + return cipher.Seal(dst, nonce[:], plaintext, additionalData) +} + +func (c *singleUseCHP) Open(dst, ciphertext, additionalData []byte) ([]byte, error) { + if c.c == nil { + panic("Attempted reuse of singleUseAEAD") + } + cipher := c.c + c.c = nil + var nonce [chp.NonceSize]byte + return cipher.Open(dst, nonce[:], ciphertext, additionalData) +} diff --git a/control/controlbase/interop_test.go b/control/controlbase/interop_test.go index d11c04149..c41fbf4dd 100644 --- a/control/controlbase/interop_test.go +++ b/control/controlbase/interop_test.go @@ -1,256 +1,256 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package controlbase - -import ( - "context" - "encoding/binary" - "errors" - "io" - "net" - "testing" - - "tailscale.com/net/memnet" - "tailscale.com/types/key" -) - -// Can a reference Noise IK client talk to our server? -func TestInteropClient(t *testing.T) { - var ( - s1, s2 = memnet.NewConn("noise", 128000) - controlKey = key.NewMachine() - machineKey = key.NewMachine() - serverErr = make(chan error, 2) - serverBytes = make(chan []byte, 1) - c2s = "client>server" - s2c = "server>client" - ) - - go func() { - server, err := Server(context.Background(), s2, controlKey, nil) - serverErr <- err - if err != nil { - return - } - var buf [1024]byte - _, err = io.ReadFull(server, buf[:len(c2s)]) - serverBytes <- buf[:len(c2s)] - if err != nil { - serverErr <- err - return - } - _, err = server.Write([]byte(s2c)) - serverErr <- err - }() - - gotS2C, err := noiseExplorerClient(s1, controlKey.Public(), machineKey, []byte(c2s)) - if err != nil { - t.Fatalf("failed client interop: %v", err) - } - if string(gotS2C) != s2c { - t.Fatalf("server sent unexpected data %q, want %q", string(gotS2C), s2c) - } - - if err := <-serverErr; err != nil { - t.Fatalf("server handshake failed: %v", err) - } - if err := <-serverErr; err != nil { - t.Fatalf("server read/write failed: %v", err) - } - if got := string(<-serverBytes); got != c2s { - t.Fatalf("server received %q, want %q", got, c2s) - } -} - -// Can our client talk to a reference Noise IK server? -func TestInteropServer(t *testing.T) { - var ( - s1, s2 = memnet.NewConn("noise", 128000) - controlKey = key.NewMachine() - machineKey = key.NewMachine() - clientErr = make(chan error, 2) - clientBytes = make(chan []byte, 1) - c2s = "client>server" - s2c = "server>client" - ) - - go func() { - client, err := Client(context.Background(), s1, machineKey, controlKey.Public(), testProtocolVersion) - clientErr <- err - if err != nil { - return - } - _, err = client.Write([]byte(c2s)) - if err != nil { - clientErr <- err - return - } - var buf [1024]byte - _, err = io.ReadFull(client, buf[:len(s2c)]) - clientBytes <- buf[:len(s2c)] - clientErr <- err - }() - - gotC2S, err := noiseExplorerServer(s2, controlKey, machineKey.Public(), []byte(s2c)) - if err != nil { - t.Fatalf("failed server interop: %v", err) - } - if string(gotC2S) != c2s { - t.Fatalf("server sent unexpected data %q, want %q", string(gotC2S), c2s) - } - - if err := <-clientErr; err != nil { - t.Fatalf("client handshake failed: %v", err) - } - if err := <-clientErr; err != nil { - t.Fatalf("client read/write failed: %v", err) - } - if got := string(<-clientBytes); got != s2c { - t.Fatalf("client received %q, want %q", got, s2c) - } -} - -// noiseExplorerClient uses the Noise Explorer implementation of Noise -// IK to handshake as a Noise client on conn, transmit payload, and -// read+return a payload from the peer. -func noiseExplorerClient(conn net.Conn, controlKey key.MachinePublic, machineKey key.MachinePrivate, payload []byte) ([]byte, error) { - var mk keypair - copy(mk.private_key[:], machineKey.UntypedBytes()) - copy(mk.public_key[:], machineKey.Public().UntypedBytes()) - var peerKey [32]byte - copy(peerKey[:], controlKey.UntypedBytes()) - session := InitSession(true, protocolVersionPrologue(testProtocolVersion), mk, peerKey) - - _, msg1 := SendMessage(&session, nil) - var hdr [initiationHeaderLen]byte - binary.BigEndian.PutUint16(hdr[:2], testProtocolVersion) - hdr[2] = msgTypeInitiation - binary.BigEndian.PutUint16(hdr[3:5], 96) - if _, err := conn.Write(hdr[:]); err != nil { - return nil, err - } - if _, err := conn.Write(msg1.ne[:]); err != nil { - return nil, err - } - if _, err := conn.Write(msg1.ns); err != nil { - return nil, err - } - if _, err := conn.Write(msg1.ciphertext); err != nil { - return nil, err - } - - var buf [1024]byte - if _, err := io.ReadFull(conn, buf[:51]); err != nil { - return nil, err - } - // ignore the header for this test, we're only checking the noise - // implementation. - msg2 := messagebuffer{ - ciphertext: buf[35:51], - } - copy(msg2.ne[:], buf[3:35]) - _, p, valid := RecvMessage(&session, &msg2) - if !valid { - return nil, errors.New("handshake failed") - } - if len(p) != 0 { - return nil, errors.New("non-empty payload") - } - - _, msg3 := SendMessage(&session, payload) - hdr[0] = msgTypeRecord - binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg3.ciphertext))) - if _, err := conn.Write(hdr[:3]); err != nil { - return nil, err - } - if _, err := conn.Write(msg3.ciphertext); err != nil { - return nil, err - } - - if _, err := io.ReadFull(conn, buf[:3]); err != nil { - return nil, err - } - // Ignore all of the header except the payload length - plen := int(binary.BigEndian.Uint16(buf[1:3])) - if _, err := io.ReadFull(conn, buf[:plen]); err != nil { - return nil, err - } - - msg4 := messagebuffer{ - ciphertext: buf[:plen], - } - _, p, valid = RecvMessage(&session, &msg4) - if !valid { - return nil, errors.New("transport message decryption failed") - } - - return p, nil -} - -func noiseExplorerServer(conn net.Conn, controlKey key.MachinePrivate, wantMachineKey key.MachinePublic, payload []byte) ([]byte, error) { - var mk keypair - copy(mk.private_key[:], controlKey.UntypedBytes()) - copy(mk.public_key[:], controlKey.Public().UntypedBytes()) - session := InitSession(false, protocolVersionPrologue(testProtocolVersion), mk, [32]byte{}) - - var buf [1024]byte - if _, err := io.ReadFull(conn, buf[:101]); err != nil { - return nil, err - } - // Ignore the header, we're just checking the noise implementation. - msg1 := messagebuffer{ - ns: buf[37:85], - ciphertext: buf[85:101], - } - copy(msg1.ne[:], buf[5:37]) - _, p, valid := RecvMessage(&session, &msg1) - if !valid { - return nil, errors.New("handshake failed") - } - if len(p) != 0 { - return nil, errors.New("non-empty payload") - } - - _, msg2 := SendMessage(&session, nil) - var hdr [headerLen]byte - hdr[0] = msgTypeResponse - binary.BigEndian.PutUint16(hdr[1:3], 48) - if _, err := conn.Write(hdr[:]); err != nil { - return nil, err - } - if _, err := conn.Write(msg2.ne[:]); err != nil { - return nil, err - } - if _, err := conn.Write(msg2.ciphertext[:]); err != nil { - return nil, err - } - - if _, err := io.ReadFull(conn, buf[:3]); err != nil { - return nil, err - } - plen := int(binary.BigEndian.Uint16(buf[1:3])) - if _, err := io.ReadFull(conn, buf[:plen]); err != nil { - return nil, err - } - - msg3 := messagebuffer{ - ciphertext: buf[:plen], - } - _, p, valid = RecvMessage(&session, &msg3) - if !valid { - return nil, errors.New("transport message decryption failed") - } - - _, msg4 := SendMessage(&session, payload) - hdr[0] = msgTypeRecord - binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg4.ciphertext))) - if _, err := conn.Write(hdr[:]); err != nil { - return nil, err - } - if _, err := conn.Write(msg4.ciphertext); err != nil { - return nil, err - } - - return p, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package controlbase + +import ( + "context" + "encoding/binary" + "errors" + "io" + "net" + "testing" + + "tailscale.com/net/memnet" + "tailscale.com/types/key" +) + +// Can a reference Noise IK client talk to our server? +func TestInteropClient(t *testing.T) { + var ( + s1, s2 = memnet.NewConn("noise", 128000) + controlKey = key.NewMachine() + machineKey = key.NewMachine() + serverErr = make(chan error, 2) + serverBytes = make(chan []byte, 1) + c2s = "client>server" + s2c = "server>client" + ) + + go func() { + server, err := Server(context.Background(), s2, controlKey, nil) + serverErr <- err + if err != nil { + return + } + var buf [1024]byte + _, err = io.ReadFull(server, buf[:len(c2s)]) + serverBytes <- buf[:len(c2s)] + if err != nil { + serverErr <- err + return + } + _, err = server.Write([]byte(s2c)) + serverErr <- err + }() + + gotS2C, err := noiseExplorerClient(s1, controlKey.Public(), machineKey, []byte(c2s)) + if err != nil { + t.Fatalf("failed client interop: %v", err) + } + if string(gotS2C) != s2c { + t.Fatalf("server sent unexpected data %q, want %q", string(gotS2C), s2c) + } + + if err := <-serverErr; err != nil { + t.Fatalf("server handshake failed: %v", err) + } + if err := <-serverErr; err != nil { + t.Fatalf("server read/write failed: %v", err) + } + if got := string(<-serverBytes); got != c2s { + t.Fatalf("server received %q, want %q", got, c2s) + } +} + +// Can our client talk to a reference Noise IK server? +func TestInteropServer(t *testing.T) { + var ( + s1, s2 = memnet.NewConn("noise", 128000) + controlKey = key.NewMachine() + machineKey = key.NewMachine() + clientErr = make(chan error, 2) + clientBytes = make(chan []byte, 1) + c2s = "client>server" + s2c = "server>client" + ) + + go func() { + client, err := Client(context.Background(), s1, machineKey, controlKey.Public(), testProtocolVersion) + clientErr <- err + if err != nil { + return + } + _, err = client.Write([]byte(c2s)) + if err != nil { + clientErr <- err + return + } + var buf [1024]byte + _, err = io.ReadFull(client, buf[:len(s2c)]) + clientBytes <- buf[:len(s2c)] + clientErr <- err + }() + + gotC2S, err := noiseExplorerServer(s2, controlKey, machineKey.Public(), []byte(s2c)) + if err != nil { + t.Fatalf("failed server interop: %v", err) + } + if string(gotC2S) != c2s { + t.Fatalf("server sent unexpected data %q, want %q", string(gotC2S), c2s) + } + + if err := <-clientErr; err != nil { + t.Fatalf("client handshake failed: %v", err) + } + if err := <-clientErr; err != nil { + t.Fatalf("client read/write failed: %v", err) + } + if got := string(<-clientBytes); got != s2c { + t.Fatalf("client received %q, want %q", got, s2c) + } +} + +// noiseExplorerClient uses the Noise Explorer implementation of Noise +// IK to handshake as a Noise client on conn, transmit payload, and +// read+return a payload from the peer. +func noiseExplorerClient(conn net.Conn, controlKey key.MachinePublic, machineKey key.MachinePrivate, payload []byte) ([]byte, error) { + var mk keypair + copy(mk.private_key[:], machineKey.UntypedBytes()) + copy(mk.public_key[:], machineKey.Public().UntypedBytes()) + var peerKey [32]byte + copy(peerKey[:], controlKey.UntypedBytes()) + session := InitSession(true, protocolVersionPrologue(testProtocolVersion), mk, peerKey) + + _, msg1 := SendMessage(&session, nil) + var hdr [initiationHeaderLen]byte + binary.BigEndian.PutUint16(hdr[:2], testProtocolVersion) + hdr[2] = msgTypeInitiation + binary.BigEndian.PutUint16(hdr[3:5], 96) + if _, err := conn.Write(hdr[:]); err != nil { + return nil, err + } + if _, err := conn.Write(msg1.ne[:]); err != nil { + return nil, err + } + if _, err := conn.Write(msg1.ns); err != nil { + return nil, err + } + if _, err := conn.Write(msg1.ciphertext); err != nil { + return nil, err + } + + var buf [1024]byte + if _, err := io.ReadFull(conn, buf[:51]); err != nil { + return nil, err + } + // ignore the header for this test, we're only checking the noise + // implementation. + msg2 := messagebuffer{ + ciphertext: buf[35:51], + } + copy(msg2.ne[:], buf[3:35]) + _, p, valid := RecvMessage(&session, &msg2) + if !valid { + return nil, errors.New("handshake failed") + } + if len(p) != 0 { + return nil, errors.New("non-empty payload") + } + + _, msg3 := SendMessage(&session, payload) + hdr[0] = msgTypeRecord + binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg3.ciphertext))) + if _, err := conn.Write(hdr[:3]); err != nil { + return nil, err + } + if _, err := conn.Write(msg3.ciphertext); err != nil { + return nil, err + } + + if _, err := io.ReadFull(conn, buf[:3]); err != nil { + return nil, err + } + // Ignore all of the header except the payload length + plen := int(binary.BigEndian.Uint16(buf[1:3])) + if _, err := io.ReadFull(conn, buf[:plen]); err != nil { + return nil, err + } + + msg4 := messagebuffer{ + ciphertext: buf[:plen], + } + _, p, valid = RecvMessage(&session, &msg4) + if !valid { + return nil, errors.New("transport message decryption failed") + } + + return p, nil +} + +func noiseExplorerServer(conn net.Conn, controlKey key.MachinePrivate, wantMachineKey key.MachinePublic, payload []byte) ([]byte, error) { + var mk keypair + copy(mk.private_key[:], controlKey.UntypedBytes()) + copy(mk.public_key[:], controlKey.Public().UntypedBytes()) + session := InitSession(false, protocolVersionPrologue(testProtocolVersion), mk, [32]byte{}) + + var buf [1024]byte + if _, err := io.ReadFull(conn, buf[:101]); err != nil { + return nil, err + } + // Ignore the header, we're just checking the noise implementation. + msg1 := messagebuffer{ + ns: buf[37:85], + ciphertext: buf[85:101], + } + copy(msg1.ne[:], buf[5:37]) + _, p, valid := RecvMessage(&session, &msg1) + if !valid { + return nil, errors.New("handshake failed") + } + if len(p) != 0 { + return nil, errors.New("non-empty payload") + } + + _, msg2 := SendMessage(&session, nil) + var hdr [headerLen]byte + hdr[0] = msgTypeResponse + binary.BigEndian.PutUint16(hdr[1:3], 48) + if _, err := conn.Write(hdr[:]); err != nil { + return nil, err + } + if _, err := conn.Write(msg2.ne[:]); err != nil { + return nil, err + } + if _, err := conn.Write(msg2.ciphertext[:]); err != nil { + return nil, err + } + + if _, err := io.ReadFull(conn, buf[:3]); err != nil { + return nil, err + } + plen := int(binary.BigEndian.Uint16(buf[1:3])) + if _, err := io.ReadFull(conn, buf[:plen]); err != nil { + return nil, err + } + + msg3 := messagebuffer{ + ciphertext: buf[:plen], + } + _, p, valid = RecvMessage(&session, &msg3) + if !valid { + return nil, errors.New("transport message decryption failed") + } + + _, msg4 := SendMessage(&session, payload) + hdr[0] = msgTypeRecord + binary.BigEndian.PutUint16(hdr[1:3], uint16(len(msg4.ciphertext))) + if _, err := conn.Write(hdr[:]); err != nil { + return nil, err + } + if _, err := conn.Write(msg4.ciphertext); err != nil { + return nil, err + } + + return p, nil +} diff --git a/control/controlbase/messages.go b/control/controlbase/messages.go index 899378681..59073088f 100644 --- a/control/controlbase/messages.go +++ b/control/controlbase/messages.go @@ -1,87 +1,87 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package controlbase - -import "encoding/binary" - -const ( - // msgTypeInitiation frames carry a Noise IK handshake initiation message. - msgTypeInitiation = 1 - // msgTypeResponse frames carry a Noise IK handshake response message. - msgTypeResponse = 2 - // msgTypeError frames carry an unauthenticated human-readable - // error message. - // - // Errors reported in this message type must be treated as public - // hints only. They are not encrypted or authenticated, and so can - // be seen and tampered with on the wire. - msgTypeError = 3 - // msgTypeRecord frames carry session data bytes. - msgTypeRecord = 4 - - // headerLen is the size of the header on all messages except msgTypeInitiation. - headerLen = 3 - // initiationHeaderLen is the size of the header on all msgTypeInitiation messages. - initiationHeaderLen = 5 -) - -// initiationMessage is the protocol message sent from a client -// machine to a control server. -// -// 2b: protocol version -// 1b: message type (0x01) -// 2b: payload length (96) -// 5b: header (see headerLen for fields) -// 32b: client ephemeral public key (cleartext) -// 48b: client machine public key (encrypted) -// 16b: message tag (authenticates the whole message) -type initiationMessage [101]byte - -func mkInitiationMessage(protocolVersion uint16) initiationMessage { - var ret initiationMessage - binary.BigEndian.PutUint16(ret[:2], protocolVersion) - ret[2] = msgTypeInitiation - binary.BigEndian.PutUint16(ret[3:5], uint16(len(ret.Payload()))) - return ret -} - -func (m *initiationMessage) Header() []byte { return m[:initiationHeaderLen] } -func (m *initiationMessage) Payload() []byte { return m[initiationHeaderLen:] } - -func (m *initiationMessage) Version() uint16 { return binary.BigEndian.Uint16(m[:2]) } -func (m *initiationMessage) Type() byte { return m[2] } -func (m *initiationMessage) Length() int { return int(binary.BigEndian.Uint16(m[3:5])) } - -func (m *initiationMessage) EphemeralPub() []byte { - return m[initiationHeaderLen : initiationHeaderLen+32] -} -func (m *initiationMessage) MachinePub() []byte { - return m[initiationHeaderLen+32 : initiationHeaderLen+32+48] -} -func (m *initiationMessage) Tag() []byte { return m[initiationHeaderLen+32+48:] } - -// responseMessage is the protocol message sent from a control server -// to a client machine. -// -// 1b: message type (0x02) -// 2b: payload length (48) -// 32b: control ephemeral public key (cleartext) -// 16b: message tag (authenticates the whole message) -type responseMessage [51]byte - -func mkResponseMessage() responseMessage { - var ret responseMessage - ret[0] = msgTypeResponse - binary.BigEndian.PutUint16(ret[1:], uint16(len(ret.Payload()))) - return ret -} - -func (m *responseMessage) Header() []byte { return m[:headerLen] } -func (m *responseMessage) Payload() []byte { return m[headerLen:] } - -func (m *responseMessage) Type() byte { return m[0] } -func (m *responseMessage) Length() int { return int(binary.BigEndian.Uint16(m[1:3])) } - -func (m *responseMessage) EphemeralPub() []byte { return m[headerLen : headerLen+32] } -func (m *responseMessage) Tag() []byte { return m[headerLen+32:] } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package controlbase + +import "encoding/binary" + +const ( + // msgTypeInitiation frames carry a Noise IK handshake initiation message. + msgTypeInitiation = 1 + // msgTypeResponse frames carry a Noise IK handshake response message. + msgTypeResponse = 2 + // msgTypeError frames carry an unauthenticated human-readable + // error message. + // + // Errors reported in this message type must be treated as public + // hints only. They are not encrypted or authenticated, and so can + // be seen and tampered with on the wire. + msgTypeError = 3 + // msgTypeRecord frames carry session data bytes. + msgTypeRecord = 4 + + // headerLen is the size of the header on all messages except msgTypeInitiation. + headerLen = 3 + // initiationHeaderLen is the size of the header on all msgTypeInitiation messages. + initiationHeaderLen = 5 +) + +// initiationMessage is the protocol message sent from a client +// machine to a control server. +// +// 2b: protocol version +// 1b: message type (0x01) +// 2b: payload length (96) +// 5b: header (see headerLen for fields) +// 32b: client ephemeral public key (cleartext) +// 48b: client machine public key (encrypted) +// 16b: message tag (authenticates the whole message) +type initiationMessage [101]byte + +func mkInitiationMessage(protocolVersion uint16) initiationMessage { + var ret initiationMessage + binary.BigEndian.PutUint16(ret[:2], protocolVersion) + ret[2] = msgTypeInitiation + binary.BigEndian.PutUint16(ret[3:5], uint16(len(ret.Payload()))) + return ret +} + +func (m *initiationMessage) Header() []byte { return m[:initiationHeaderLen] } +func (m *initiationMessage) Payload() []byte { return m[initiationHeaderLen:] } + +func (m *initiationMessage) Version() uint16 { return binary.BigEndian.Uint16(m[:2]) } +func (m *initiationMessage) Type() byte { return m[2] } +func (m *initiationMessage) Length() int { return int(binary.BigEndian.Uint16(m[3:5])) } + +func (m *initiationMessage) EphemeralPub() []byte { + return m[initiationHeaderLen : initiationHeaderLen+32] +} +func (m *initiationMessage) MachinePub() []byte { + return m[initiationHeaderLen+32 : initiationHeaderLen+32+48] +} +func (m *initiationMessage) Tag() []byte { return m[initiationHeaderLen+32+48:] } + +// responseMessage is the protocol message sent from a control server +// to a client machine. +// +// 1b: message type (0x02) +// 2b: payload length (48) +// 32b: control ephemeral public key (cleartext) +// 16b: message tag (authenticates the whole message) +type responseMessage [51]byte + +func mkResponseMessage() responseMessage { + var ret responseMessage + ret[0] = msgTypeResponse + binary.BigEndian.PutUint16(ret[1:], uint16(len(ret.Payload()))) + return ret +} + +func (m *responseMessage) Header() []byte { return m[:headerLen] } +func (m *responseMessage) Payload() []byte { return m[headerLen:] } + +func (m *responseMessage) Type() byte { return m[0] } +func (m *responseMessage) Length() int { return int(binary.BigEndian.Uint16(m[1:3])) } + +func (m *responseMessage) EphemeralPub() []byte { return m[headerLen : headerLen+32] } +func (m *responseMessage) Tag() []byte { return m[headerLen+32:] } diff --git a/control/controlclient/sign.go b/control/controlclient/sign.go index 5e72f1cf4..e3a479c28 100644 --- a/control/controlclient/sign.go +++ b/control/controlclient/sign.go @@ -1,42 +1,42 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package controlclient - -import ( - "crypto" - "errors" - "fmt" - "time" - - "tailscale.com/tailcfg" - "tailscale.com/types/key" -) - -var ( - errNoCertStore = errors.New("no certificate store") - errCertificateNotConfigured = errors.New("no certificate subject configured") - errUnsupportedSignatureVersion = errors.New("unsupported signature version") -) - -// HashRegisterRequest generates the hash required sign or verify a -// tailcfg.RegisterRequest. -func HashRegisterRequest( - version tailcfg.SignatureType, ts time.Time, serverURL string, deviceCert []byte, - serverPubKey, machinePubKey key.MachinePublic) ([]byte, error) { - h := crypto.SHA256.New() - - // hash.Hash.Write never returns an error, so we don't check for one here. - switch version { - case tailcfg.SignatureV1: - fmt.Fprintf(h, "%s%s%s%s%s", - ts.UTC().Format(time.RFC3339), serverURL, deviceCert, serverPubKey.ShortString(), machinePubKey.ShortString()) - case tailcfg.SignatureV2: - fmt.Fprintf(h, "%s%s%s%s%s", - ts.UTC().Format(time.RFC3339), serverURL, deviceCert, serverPubKey, machinePubKey) - default: - return nil, errUnsupportedSignatureVersion - } - - return h.Sum(nil), nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package controlclient + +import ( + "crypto" + "errors" + "fmt" + "time" + + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +var ( + errNoCertStore = errors.New("no certificate store") + errCertificateNotConfigured = errors.New("no certificate subject configured") + errUnsupportedSignatureVersion = errors.New("unsupported signature version") +) + +// HashRegisterRequest generates the hash required sign or verify a +// tailcfg.RegisterRequest. +func HashRegisterRequest( + version tailcfg.SignatureType, ts time.Time, serverURL string, deviceCert []byte, + serverPubKey, machinePubKey key.MachinePublic) ([]byte, error) { + h := crypto.SHA256.New() + + // hash.Hash.Write never returns an error, so we don't check for one here. + switch version { + case tailcfg.SignatureV1: + fmt.Fprintf(h, "%s%s%s%s%s", + ts.UTC().Format(time.RFC3339), serverURL, deviceCert, serverPubKey.ShortString(), machinePubKey.ShortString()) + case tailcfg.SignatureV2: + fmt.Fprintf(h, "%s%s%s%s%s", + ts.UTC().Format(time.RFC3339), serverURL, deviceCert, serverPubKey, machinePubKey) + default: + return nil, errUnsupportedSignatureVersion + } + + return h.Sum(nil), nil +} diff --git a/control/controlclient/sign_supported_test.go b/control/controlclient/sign_supported_test.go index ca41794d1..e20349a4e 100644 --- a/control/controlclient/sign_supported_test.go +++ b/control/controlclient/sign_supported_test.go @@ -1,236 +1,236 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build windows && cgo - -package controlclient - -import ( - "crypto" - "crypto/x509" - "crypto/x509/pkix" - "errors" - "reflect" - "testing" - "time" - - "github.com/tailscale/certstore" -) - -const ( - testRootCommonName = "testroot" - testRootSubject = "CN=testroot" -) - -type testIdentity struct { - chain []*x509.Certificate -} - -func makeChain(rootCommonName string, notBefore, notAfter time.Time) []*x509.Certificate { - return []*x509.Certificate{ - { - NotBefore: notBefore, - NotAfter: notAfter, - PublicKeyAlgorithm: x509.RSA, - }, - { - Subject: pkix.Name{ - CommonName: rootCommonName, - }, - PublicKeyAlgorithm: x509.RSA, - }, - } -} - -func (t *testIdentity) Certificate() (*x509.Certificate, error) { - return t.chain[0], nil -} - -func (t *testIdentity) CertificateChain() ([]*x509.Certificate, error) { - return t.chain, nil -} - -func (t *testIdentity) Signer() (crypto.Signer, error) { - return nil, errors.New("not implemented") -} - -func (t *testIdentity) Delete() error { - return errors.New("not implemented") -} - -func (t *testIdentity) Close() {} - -func TestSelectIdentityFromSlice(t *testing.T) { - var times []time.Time - for _, ts := range []string{ - "2000-01-01T00:00:00Z", - "2001-01-01T00:00:00Z", - "2002-01-01T00:00:00Z", - "2003-01-01T00:00:00Z", - } { - tm, err := time.Parse(time.RFC3339, ts) - if err != nil { - t.Fatal(err) - } - times = append(times, tm) - } - - tests := []struct { - name string - subject string - ids []certstore.Identity - now time.Time - // wantIndex is an index into ids, or -1 for nil. - wantIndex int - }{ - { - name: "single unexpired identity", - subject: testRootSubject, - ids: []certstore.Identity{ - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[2]), - }, - }, - now: times[1], - wantIndex: 0, - }, - { - name: "single expired identity", - subject: testRootSubject, - ids: []certstore.Identity{ - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[1]), - }, - }, - now: times[2], - wantIndex: -1, - }, - { - name: "unrelated ids", - subject: testRootSubject, - ids: []certstore.Identity{ - &testIdentity{ - chain: makeChain("something", times[0], times[2]), - }, - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[2]), - }, - &testIdentity{ - chain: makeChain("else", times[0], times[2]), - }, - }, - now: times[1], - wantIndex: 1, - }, - { - name: "expired with unrelated ids", - subject: testRootSubject, - ids: []certstore.Identity{ - &testIdentity{ - chain: makeChain("something", times[0], times[3]), - }, - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[1]), - }, - &testIdentity{ - chain: makeChain("else", times[0], times[3]), - }, - }, - now: times[2], - wantIndex: -1, - }, - { - name: "one expired", - subject: testRootSubject, - ids: []certstore.Identity{ - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[1]), - }, - &testIdentity{ - chain: makeChain(testRootCommonName, times[1], times[3]), - }, - }, - now: times[2], - wantIndex: 1, - }, - { - name: "two certs both unexpired", - subject: testRootSubject, - ids: []certstore.Identity{ - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[3]), - }, - &testIdentity{ - chain: makeChain(testRootCommonName, times[1], times[3]), - }, - }, - now: times[2], - wantIndex: 1, - }, - { - name: "two unexpired one expired", - subject: testRootSubject, - ids: []certstore.Identity{ - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[3]), - }, - &testIdentity{ - chain: makeChain(testRootCommonName, times[1], times[3]), - }, - &testIdentity{ - chain: makeChain(testRootCommonName, times[0], times[1]), - }, - }, - now: times[2], - wantIndex: 1, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - gotId, gotChain := selectIdentityFromSlice(tt.subject, tt.ids, tt.now) - - if gotId == nil && gotChain != nil { - t.Error("id is nil: got non-nil chain, want nil chain") - return - } - if gotId != nil && gotChain == nil { - t.Error("id is not nil: got nil chain, want non-nil chain") - return - } - if tt.wantIndex == -1 { - if gotId != nil { - t.Error("got non-nil id, want nil id") - } - return - } - if gotId == nil { - t.Error("got nil id, want non-nil id") - return - } - if gotId != tt.ids[tt.wantIndex] { - found := -1 - for i := range tt.ids { - if tt.ids[i] == gotId { - found = i - break - } - } - if found == -1 { - t.Errorf("got unknown id, want id at index %v", tt.wantIndex) - } else { - t.Errorf("got id at index %v, want id at index %v", found, tt.wantIndex) - } - } - - tid, ok := tt.ids[tt.wantIndex].(*testIdentity) - if !ok { - t.Error("got non-testIdentity, want testIdentity") - return - } - - if !reflect.DeepEqual(tid.chain, gotChain) { - t.Errorf("got unknown chain, want chain from id at index %v", tt.wantIndex) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows && cgo + +package controlclient + +import ( + "crypto" + "crypto/x509" + "crypto/x509/pkix" + "errors" + "reflect" + "testing" + "time" + + "github.com/tailscale/certstore" +) + +const ( + testRootCommonName = "testroot" + testRootSubject = "CN=testroot" +) + +type testIdentity struct { + chain []*x509.Certificate +} + +func makeChain(rootCommonName string, notBefore, notAfter time.Time) []*x509.Certificate { + return []*x509.Certificate{ + { + NotBefore: notBefore, + NotAfter: notAfter, + PublicKeyAlgorithm: x509.RSA, + }, + { + Subject: pkix.Name{ + CommonName: rootCommonName, + }, + PublicKeyAlgorithm: x509.RSA, + }, + } +} + +func (t *testIdentity) Certificate() (*x509.Certificate, error) { + return t.chain[0], nil +} + +func (t *testIdentity) CertificateChain() ([]*x509.Certificate, error) { + return t.chain, nil +} + +func (t *testIdentity) Signer() (crypto.Signer, error) { + return nil, errors.New("not implemented") +} + +func (t *testIdentity) Delete() error { + return errors.New("not implemented") +} + +func (t *testIdentity) Close() {} + +func TestSelectIdentityFromSlice(t *testing.T) { + var times []time.Time + for _, ts := range []string{ + "2000-01-01T00:00:00Z", + "2001-01-01T00:00:00Z", + "2002-01-01T00:00:00Z", + "2003-01-01T00:00:00Z", + } { + tm, err := time.Parse(time.RFC3339, ts) + if err != nil { + t.Fatal(err) + } + times = append(times, tm) + } + + tests := []struct { + name string + subject string + ids []certstore.Identity + now time.Time + // wantIndex is an index into ids, or -1 for nil. + wantIndex int + }{ + { + name: "single unexpired identity", + subject: testRootSubject, + ids: []certstore.Identity{ + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[2]), + }, + }, + now: times[1], + wantIndex: 0, + }, + { + name: "single expired identity", + subject: testRootSubject, + ids: []certstore.Identity{ + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[1]), + }, + }, + now: times[2], + wantIndex: -1, + }, + { + name: "unrelated ids", + subject: testRootSubject, + ids: []certstore.Identity{ + &testIdentity{ + chain: makeChain("something", times[0], times[2]), + }, + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[2]), + }, + &testIdentity{ + chain: makeChain("else", times[0], times[2]), + }, + }, + now: times[1], + wantIndex: 1, + }, + { + name: "expired with unrelated ids", + subject: testRootSubject, + ids: []certstore.Identity{ + &testIdentity{ + chain: makeChain("something", times[0], times[3]), + }, + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[1]), + }, + &testIdentity{ + chain: makeChain("else", times[0], times[3]), + }, + }, + now: times[2], + wantIndex: -1, + }, + { + name: "one expired", + subject: testRootSubject, + ids: []certstore.Identity{ + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[1]), + }, + &testIdentity{ + chain: makeChain(testRootCommonName, times[1], times[3]), + }, + }, + now: times[2], + wantIndex: 1, + }, + { + name: "two certs both unexpired", + subject: testRootSubject, + ids: []certstore.Identity{ + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[3]), + }, + &testIdentity{ + chain: makeChain(testRootCommonName, times[1], times[3]), + }, + }, + now: times[2], + wantIndex: 1, + }, + { + name: "two unexpired one expired", + subject: testRootSubject, + ids: []certstore.Identity{ + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[3]), + }, + &testIdentity{ + chain: makeChain(testRootCommonName, times[1], times[3]), + }, + &testIdentity{ + chain: makeChain(testRootCommonName, times[0], times[1]), + }, + }, + now: times[2], + wantIndex: 1, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotId, gotChain := selectIdentityFromSlice(tt.subject, tt.ids, tt.now) + + if gotId == nil && gotChain != nil { + t.Error("id is nil: got non-nil chain, want nil chain") + return + } + if gotId != nil && gotChain == nil { + t.Error("id is not nil: got nil chain, want non-nil chain") + return + } + if tt.wantIndex == -1 { + if gotId != nil { + t.Error("got non-nil id, want nil id") + } + return + } + if gotId == nil { + t.Error("got nil id, want non-nil id") + return + } + if gotId != tt.ids[tt.wantIndex] { + found := -1 + for i := range tt.ids { + if tt.ids[i] == gotId { + found = i + break + } + } + if found == -1 { + t.Errorf("got unknown id, want id at index %v", tt.wantIndex) + } else { + t.Errorf("got id at index %v, want id at index %v", found, tt.wantIndex) + } + } + + tid, ok := tt.ids[tt.wantIndex].(*testIdentity) + if !ok { + t.Error("got non-testIdentity, want testIdentity") + return + } + + if !reflect.DeepEqual(tid.chain, gotChain) { + t.Errorf("got unknown chain, want chain from id at index %v", tt.wantIndex) + } + }) + } +} diff --git a/control/controlclient/sign_unsupported.go b/control/controlclient/sign_unsupported.go index 4ec40d502..5e161dcbc 100644 --- a/control/controlclient/sign_unsupported.go +++ b/control/controlclient/sign_unsupported.go @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows - -package controlclient - -import ( - "tailscale.com/tailcfg" - "tailscale.com/types/key" -) - -// signRegisterRequest on non-supported platforms always returns errNoCertStore. -func signRegisterRequest(req *tailcfg.RegisterRequest, serverURL string, serverPubKey, machinePubKey key.MachinePublic) error { - return errNoCertStore -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows + +package controlclient + +import ( + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +// signRegisterRequest on non-supported platforms always returns errNoCertStore. +func signRegisterRequest(req *tailcfg.RegisterRequest, serverURL string, serverPubKey, machinePubKey key.MachinePublic) error { + return errNoCertStore +} diff --git a/control/controlclient/status.go b/control/controlclient/status.go index 7dba14d3f..d0fdf80d7 100644 --- a/control/controlclient/status.go +++ b/control/controlclient/status.go @@ -1,125 +1,125 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package controlclient - -import ( - "encoding/json" - "fmt" - "reflect" - - "tailscale.com/types/netmap" - "tailscale.com/types/persist" - "tailscale.com/types/structs" -) - -// State is the high-level state of the client. It is used only in -// unit tests for proper sequencing, don't depend on it anywhere else. -// -// TODO(apenwarr): eliminate the state, as it's now obsolete. -// -// apenwarr: Historical note: controlclient.Auto was originally -// intended to be the state machine for the whole tailscale client, but that -// turned out to not be the right abstraction layer, and it moved to -// ipn.Backend. Since ipn.Backend now has a state machine, it would be -// much better if controlclient could be a simple stateless API. But the -// current server-side API (two interlocking polling https calls) makes that -// very hard to implement. A server side API change could untangle this and -// remove all the statefulness. -type State int - -const ( - StateNew = State(iota) - StateNotAuthenticated - StateAuthenticating - StateURLVisitRequired - StateAuthenticated - StateSynchronized // connected and received map update -) - -func (s State) AppendText(b []byte) ([]byte, error) { - return append(b, s.String()...), nil -} - -func (s State) MarshalText() ([]byte, error) { - return []byte(s.String()), nil -} - -func (s State) String() string { - switch s { - case StateNew: - return "state:new" - case StateNotAuthenticated: - return "state:not-authenticated" - case StateAuthenticating: - return "state:authenticating" - case StateURLVisitRequired: - return "state:url-visit-required" - case StateAuthenticated: - return "state:authenticated" - case StateSynchronized: - return "state:synchronized" - default: - return fmt.Sprintf("state:unknown:%d", int(s)) - } -} - -type Status struct { - _ structs.Incomparable - - // Err, if non-nil, is an error that occurred while logging in. - // - // If it's of type UserVisibleError then it's meant to be shown to users in - // their Tailscale client. Otherwise it's just logged to tailscaled's logs. - Err error - - // URL, if non-empty, is the interactive URL to visit to finish logging in. - URL string - - // NetMap is the latest server-pushed state of the tailnet network. - NetMap *netmap.NetworkMap - - // Persist, when Valid, is the locally persisted configuration. - // - // TODO(bradfitz,maisem): clarify this. - Persist persist.PersistView - - // state is the internal state. It should not be exposed outside this - // package, but we have some automated tests elsewhere that need to - // use it via the StateForTest accessor. - // TODO(apenwarr): Unexport or remove these. - state State -} - -// LoginFinished reports whether the controlclient is in its "StateAuthenticated" -// state where it's in a happy register state but not yet in a map poll. -// -// TODO(bradfitz): delete this and everything around Status.state. -func (s *Status) LoginFinished() bool { return s.state == StateAuthenticated } - -// StateForTest returns the internal state of s for tests only. -func (s *Status) StateForTest() State { return s.state } - -// SetStateForTest sets the internal state of s for tests only. -func (s *Status) SetStateForTest(state State) { s.state = state } - -// Equal reports whether s and s2 are equal. -func (s *Status) Equal(s2 *Status) bool { - if s == nil && s2 == nil { - return true - } - return s != nil && s2 != nil && - s.Err == s2.Err && - s.URL == s2.URL && - s.state == s2.state && - reflect.DeepEqual(s.Persist, s2.Persist) && - reflect.DeepEqual(s.NetMap, s2.NetMap) -} - -func (s Status) String() string { - b, err := json.MarshalIndent(s, "", "\t") - if err != nil { - panic(err) - } - return s.state.String() + " " + string(b) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package controlclient + +import ( + "encoding/json" + "fmt" + "reflect" + + "tailscale.com/types/netmap" + "tailscale.com/types/persist" + "tailscale.com/types/structs" +) + +// State is the high-level state of the client. It is used only in +// unit tests for proper sequencing, don't depend on it anywhere else. +// +// TODO(apenwarr): eliminate the state, as it's now obsolete. +// +// apenwarr: Historical note: controlclient.Auto was originally +// intended to be the state machine for the whole tailscale client, but that +// turned out to not be the right abstraction layer, and it moved to +// ipn.Backend. Since ipn.Backend now has a state machine, it would be +// much better if controlclient could be a simple stateless API. But the +// current server-side API (two interlocking polling https calls) makes that +// very hard to implement. A server side API change could untangle this and +// remove all the statefulness. +type State int + +const ( + StateNew = State(iota) + StateNotAuthenticated + StateAuthenticating + StateURLVisitRequired + StateAuthenticated + StateSynchronized // connected and received map update +) + +func (s State) AppendText(b []byte) ([]byte, error) { + return append(b, s.String()...), nil +} + +func (s State) MarshalText() ([]byte, error) { + return []byte(s.String()), nil +} + +func (s State) String() string { + switch s { + case StateNew: + return "state:new" + case StateNotAuthenticated: + return "state:not-authenticated" + case StateAuthenticating: + return "state:authenticating" + case StateURLVisitRequired: + return "state:url-visit-required" + case StateAuthenticated: + return "state:authenticated" + case StateSynchronized: + return "state:synchronized" + default: + return fmt.Sprintf("state:unknown:%d", int(s)) + } +} + +type Status struct { + _ structs.Incomparable + + // Err, if non-nil, is an error that occurred while logging in. + // + // If it's of type UserVisibleError then it's meant to be shown to users in + // their Tailscale client. Otherwise it's just logged to tailscaled's logs. + Err error + + // URL, if non-empty, is the interactive URL to visit to finish logging in. + URL string + + // NetMap is the latest server-pushed state of the tailnet network. + NetMap *netmap.NetworkMap + + // Persist, when Valid, is the locally persisted configuration. + // + // TODO(bradfitz,maisem): clarify this. + Persist persist.PersistView + + // state is the internal state. It should not be exposed outside this + // package, but we have some automated tests elsewhere that need to + // use it via the StateForTest accessor. + // TODO(apenwarr): Unexport or remove these. + state State +} + +// LoginFinished reports whether the controlclient is in its "StateAuthenticated" +// state where it's in a happy register state but not yet in a map poll. +// +// TODO(bradfitz): delete this and everything around Status.state. +func (s *Status) LoginFinished() bool { return s.state == StateAuthenticated } + +// StateForTest returns the internal state of s for tests only. +func (s *Status) StateForTest() State { return s.state } + +// SetStateForTest sets the internal state of s for tests only. +func (s *Status) SetStateForTest(state State) { s.state = state } + +// Equal reports whether s and s2 are equal. +func (s *Status) Equal(s2 *Status) bool { + if s == nil && s2 == nil { + return true + } + return s != nil && s2 != nil && + s.Err == s2.Err && + s.URL == s2.URL && + s.state == s2.state && + reflect.DeepEqual(s.Persist, s2.Persist) && + reflect.DeepEqual(s.NetMap, s2.NetMap) +} + +func (s Status) String() string { + b, err := json.MarshalIndent(s, "", "\t") + if err != nil { + panic(err) + } + return s.state.String() + " " + string(b) +} diff --git a/control/controlhttp/client_common.go b/control/controlhttp/client_common.go index 72a89e3cd..dd94e93cd 100644 --- a/control/controlhttp/client_common.go +++ b/control/controlhttp/client_common.go @@ -1,17 +1,17 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package controlhttp - -import ( - "tailscale.com/control/controlbase" -) - -// ClientConn is a Tailscale control client as returned by the Dialer. -// -// It's effectively just a *controlbase.Conn (which it embeds) with -// optional metadata. -type ClientConn struct { - // Conn is the noise connection. - *controlbase.Conn -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package controlhttp + +import ( + "tailscale.com/control/controlbase" +) + +// ClientConn is a Tailscale control client as returned by the Dialer. +// +// It's effectively just a *controlbase.Conn (which it embeds) with +// optional metadata. +type ClientConn struct { + // Conn is the noise connection. + *controlbase.Conn +} diff --git a/derp/README.md b/derp/README.md index acd986ea9..16877020d 100644 --- a/derp/README.md +++ b/derp/README.md @@ -1,61 +1,61 @@ -# DERP - -This directory (and subdirectories) contain the DERP code. The server itself is -in `../cmd/derper`. - -DERP is a packet relay system (client and servers) where peers are addressed -using WireGuard public keys instead of IP addresses. - -It relays two types of packets: - -* "Disco" discovery messages (see `../disco`) as the a side channel during [NAT - traversal](https://tailscale.com/blog/how-nat-traversal-works/). - -* Encrypted WireGuard packets as the fallback of last resort when UDP is blocked - or NAT traversal fails. - -## DERP Map - -Each client receives a "[DERP -Map](https://pkg.go.dev/tailscale.com/tailcfg#DERPMap)" from the coordination -server describing the DERP servers the client should try to use. - -The client picks its home "DERP home" based on latency. This is done to keep -costs low by avoid using cloud load balancers (pricey) or anycast, which would -necessarily require server-side routing between DERP regions. - -Clients pick their DERP home and report it to the coordination server which -shares it to all the peers in the tailnet. When a peer wants to send a packet -and it doesn't already have a WireGuard session open, it sends disco messages -(some direct, and some over DERP), trying to do the NAT traversal. The client -will make connections to multiple DERP regions as needed. Only the DERP home -region connection needs to be alive forever. - -## DERP Regions - -Tailscale runs 1 or more DERP nodes (instances of `cmd/derper`) in various -geographic regions to make sure users have low latency to their DERP home. - -Regions generally have multiple nodes per region "meshed" (routing to each -other) together for redundancy: it allows for cloud failures or upgrades without -kicking users out to a higher latency region. Instead, clients will reconnect to -the next node in the region. Each node in the region is required to to be meshed -with every other node in the region and forward packets to the other nodes in -the region. Packets are forwarded only one hop within the region. There is no -routing between regions. The assumption is that the mesh TCP connections are -over a VPC that's very fast, low latency, and not charged per byte. The -coordination server assigns the list of nodes in a region as a function of the -tailnet, so all nodes within a tailnet should generally be on the same node and -not require forwarding. Only after a failure do clients of a particular tailnet -get split between nodes in a region and require inter-node forwarding. But over -time it balances back out. There's also an admin-only DERP frame type to force -close the TCP connection of a particular client to force them to reconnect to -their primary if the operator wants to force things to balance out sooner. -(Using the `(*derphttp.Client).ClosePeer` method, as used by Tailscale's -internal rarely-used `cmd/derpprune` maintenance tool) - -We generally run a minimum of three nodes in a region not for quorum reasons -(there's no voting) but just because two is too uncomfortably few for cascading -failure reasons: if you're running two nodes at 51% load (CPU, memory, etc) and -then one fails, that makes the second one fail. With three or more nodes, you +# DERP + +This directory (and subdirectories) contain the DERP code. The server itself is +in `../cmd/derper`. + +DERP is a packet relay system (client and servers) where peers are addressed +using WireGuard public keys instead of IP addresses. + +It relays two types of packets: + +* "Disco" discovery messages (see `../disco`) as the a side channel during [NAT + traversal](https://tailscale.com/blog/how-nat-traversal-works/). + +* Encrypted WireGuard packets as the fallback of last resort when UDP is blocked + or NAT traversal fails. + +## DERP Map + +Each client receives a "[DERP +Map](https://pkg.go.dev/tailscale.com/tailcfg#DERPMap)" from the coordination +server describing the DERP servers the client should try to use. + +The client picks its home "DERP home" based on latency. This is done to keep +costs low by avoid using cloud load balancers (pricey) or anycast, which would +necessarily require server-side routing between DERP regions. + +Clients pick their DERP home and report it to the coordination server which +shares it to all the peers in the tailnet. When a peer wants to send a packet +and it doesn't already have a WireGuard session open, it sends disco messages +(some direct, and some over DERP), trying to do the NAT traversal. The client +will make connections to multiple DERP regions as needed. Only the DERP home +region connection needs to be alive forever. + +## DERP Regions + +Tailscale runs 1 or more DERP nodes (instances of `cmd/derper`) in various +geographic regions to make sure users have low latency to their DERP home. + +Regions generally have multiple nodes per region "meshed" (routing to each +other) together for redundancy: it allows for cloud failures or upgrades without +kicking users out to a higher latency region. Instead, clients will reconnect to +the next node in the region. Each node in the region is required to to be meshed +with every other node in the region and forward packets to the other nodes in +the region. Packets are forwarded only one hop within the region. There is no +routing between regions. The assumption is that the mesh TCP connections are +over a VPC that's very fast, low latency, and not charged per byte. The +coordination server assigns the list of nodes in a region as a function of the +tailnet, so all nodes within a tailnet should generally be on the same node and +not require forwarding. Only after a failure do clients of a particular tailnet +get split between nodes in a region and require inter-node forwarding. But over +time it balances back out. There's also an admin-only DERP frame type to force +close the TCP connection of a particular client to force them to reconnect to +their primary if the operator wants to force things to balance out sooner. +(Using the `(*derphttp.Client).ClosePeer` method, as used by Tailscale's +internal rarely-used `cmd/derpprune` maintenance tool) + +We generally run a minimum of three nodes in a region not for quorum reasons +(there's no voting) but just because two is too uncomfortably few for cascading +failure reasons: if you're running two nodes at 51% load (CPU, memory, etc) and +then one fails, that makes the second one fail. With three or more nodes, you can run each node a bit hotter. \ No newline at end of file diff --git a/derp/testdata/example_ss.txt b/derp/testdata/example_ss.txt index ae25003b2..2885f1bc1 100644 --- a/derp/testdata/example_ss.txt +++ b/derp/testdata/example_ss.txt @@ -1,8 +1,8 @@ -ESTAB 0 0 10.255.1.11:35238 34.210.105.16:https - cubic wscale:7,7 rto:236 rtt:34.14/3.432 ato:40 mss:1448 pmtu:1500 rcvmss:1448 advmss:1448 cwnd:8 ssthresh:6 bytes_sent:38056577 bytes_retrans:2918 bytes_acked:38053660 bytes_received:6973211 segs_out:165090 segs_in:124227 data_segs_out:78018 data_segs_in:71645 send 2.71Mbps lastsnd:1156 lastrcv:1120 lastack:1120 pacing_rate 3.26Mbps delivery_rate 2.35Mbps delivered:78017 app_limited busy:2586132ms retrans:0/6 dsack_dups:4 reordering:5 reord_seen:15 rcv_rtt:126355 rcv_space:65780 rcv_ssthresh:541928 minrtt:26.632 -ESTAB 0 80 100.79.58.14:ssh 100.95.73.104:58145 - cubic wscale:6,7 rto:224 rtt:23.051/2.03 ato:172 mss:1228 pmtu:1280 rcvmss:1228 advmss:1228 cwnd:10 ssthresh:94 bytes_sent:1591815 bytes_retrans:944 bytes_acked:1590791 bytes_received:158925 segs_out:8070 segs_in:8858 data_segs_out:7452 data_segs_in:3789 send 4.26Mbps lastsnd:4 lastrcv:4 lastack:4 pacing_rate 8.52Mbps delivery_rate 10.9Mbps delivered:7451 app_limited busy:61656ms unacked:2 retrans:0/10 dsack_dups:10 rcv_rtt:174712 rcv_space:65025 rcv_ssthresh:64296 minrtt:16.186 -ESTAB 0 374 10.255.1.11:43254 167.172.206.31:https - cubic wscale:7,7 rto:224 rtt:22.55/1.941 ato:40 mss:1448 pmtu:1500 rcvmss:1448 advmss:1448 cwnd:6 ssthresh:4 bytes_sent:14594668 bytes_retrans:173314 bytes_acked:14420981 bytes_received:4207111 segs_out:80566 segs_in:70310 data_segs_out:24317 data_segs_in:20365 send 3.08Mbps lastsnd:4 lastrcv:4 lastack:4 pacing_rate 3.7Mbps delivery_rate 3.05Mbps delivered:24111 app_limited busy:184820ms unacked:2 retrans:0/185 dsack_dups:1 reord_seen:3 rcv_rtt:651.262 rcv_space:226657 rcv_ssthresh:1557136 minrtt:10.18 -ESTAB 0 0 10.255.1.11:33036 3.121.18.47:https - cubic wscale:7,7 rto:372 rtt:168.408/2.044 ato:40 mss:1448 pmtu:1500 rcvmss:1448 advmss:1448 cwnd:10 bytes_sent:27500 bytes_acked:27501 bytes_received:1386524 segs_out:10990 segs_in:11037 data_segs_out:303 data_segs_in:3414 send 688kbps lastsnd:125776 lastrcv:9640 lastack:22760 pacing_rate 1.38Mbps delivery_rate 482kbps delivered:304 app_limited busy:43024ms rcv_rtt:3345.12 rcv_space:62431 rcv_ssthresh:760472 minrtt:168.867 +ESTAB 0 0 10.255.1.11:35238 34.210.105.16:https + cubic wscale:7,7 rto:236 rtt:34.14/3.432 ato:40 mss:1448 pmtu:1500 rcvmss:1448 advmss:1448 cwnd:8 ssthresh:6 bytes_sent:38056577 bytes_retrans:2918 bytes_acked:38053660 bytes_received:6973211 segs_out:165090 segs_in:124227 data_segs_out:78018 data_segs_in:71645 send 2.71Mbps lastsnd:1156 lastrcv:1120 lastack:1120 pacing_rate 3.26Mbps delivery_rate 2.35Mbps delivered:78017 app_limited busy:2586132ms retrans:0/6 dsack_dups:4 reordering:5 reord_seen:15 rcv_rtt:126355 rcv_space:65780 rcv_ssthresh:541928 minrtt:26.632 +ESTAB 0 80 100.79.58.14:ssh 100.95.73.104:58145 + cubic wscale:6,7 rto:224 rtt:23.051/2.03 ato:172 mss:1228 pmtu:1280 rcvmss:1228 advmss:1228 cwnd:10 ssthresh:94 bytes_sent:1591815 bytes_retrans:944 bytes_acked:1590791 bytes_received:158925 segs_out:8070 segs_in:8858 data_segs_out:7452 data_segs_in:3789 send 4.26Mbps lastsnd:4 lastrcv:4 lastack:4 pacing_rate 8.52Mbps delivery_rate 10.9Mbps delivered:7451 app_limited busy:61656ms unacked:2 retrans:0/10 dsack_dups:10 rcv_rtt:174712 rcv_space:65025 rcv_ssthresh:64296 minrtt:16.186 +ESTAB 0 374 10.255.1.11:43254 167.172.206.31:https + cubic wscale:7,7 rto:224 rtt:22.55/1.941 ato:40 mss:1448 pmtu:1500 rcvmss:1448 advmss:1448 cwnd:6 ssthresh:4 bytes_sent:14594668 bytes_retrans:173314 bytes_acked:14420981 bytes_received:4207111 segs_out:80566 segs_in:70310 data_segs_out:24317 data_segs_in:20365 send 3.08Mbps lastsnd:4 lastrcv:4 lastack:4 pacing_rate 3.7Mbps delivery_rate 3.05Mbps delivered:24111 app_limited busy:184820ms unacked:2 retrans:0/185 dsack_dups:1 reord_seen:3 rcv_rtt:651.262 rcv_space:226657 rcv_ssthresh:1557136 minrtt:10.18 +ESTAB 0 0 10.255.1.11:33036 3.121.18.47:https + cubic wscale:7,7 rto:372 rtt:168.408/2.044 ato:40 mss:1448 pmtu:1500 rcvmss:1448 advmss:1448 cwnd:10 bytes_sent:27500 bytes_acked:27501 bytes_received:1386524 segs_out:10990 segs_in:11037 data_segs_out:303 data_segs_in:3414 send 688kbps lastsnd:125776 lastrcv:9640 lastack:22760 pacing_rate 1.38Mbps delivery_rate 482kbps delivered:304 app_limited busy:43024ms rcv_rtt:3345.12 rcv_space:62431 rcv_ssthresh:760472 minrtt:168.867 diff --git a/disco/disco_fuzzer.go b/disco/disco_fuzzer.go index 0deede050..b9ffabfb0 100644 --- a/disco/disco_fuzzer.go +++ b/disco/disco_fuzzer.go @@ -1,17 +1,17 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause -//go:build gofuzz - -package disco - -func Fuzz(data []byte) int { - m, _ := Parse(data) - - newBytes := m.AppendMarshal(data) - parsedMarshall, _ := Parse(newBytes) - - if m != parsedMarshall { - panic("Parsing error") - } - return 1 -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause +//go:build gofuzz + +package disco + +func Fuzz(data []byte) int { + m, _ := Parse(data) + + newBytes := m.AppendMarshal(data) + parsedMarshall, _ := Parse(newBytes) + + if m != parsedMarshall { + panic("Parsing error") + } + return 1 +} diff --git a/disco/disco_test.go b/disco/disco_test.go index 045425eb7..1a56324a5 100644 --- a/disco/disco_test.go +++ b/disco/disco_test.go @@ -1,118 +1,118 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package disco - -import ( - "fmt" - "net/netip" - "reflect" - "strings" - "testing" - - "go4.org/mem" - "tailscale.com/types/key" -) - -func TestMarshalAndParse(t *testing.T) { - tests := []struct { - name string - want string - m Message - }{ - { - name: "ping", - m: &Ping{ - TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, - }, - want: "01 00 01 02 03 04 05 06 07 08 09 0a 0b 0c", - }, - { - name: "ping_with_nodekey_src", - m: &Ping{ - TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, - NodeKey: key.NodePublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 30: 30, 31: 31})), - }, - want: "01 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 01 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f", - }, - { - name: "ping_with_padding", - m: &Ping{ - TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, - Padding: 3, - }, - want: "01 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 00 00", - }, - { - name: "ping_with_padding_and_nodekey_src", - m: &Ping{ - TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, - NodeKey: key.NodePublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 30: 30, 31: 31})), - Padding: 3, - }, - want: "01 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 01 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f 00 00 00", - }, - { - name: "pong", - m: &Pong{ - TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, - Src: mustIPPort("2.3.4.5:1234"), - }, - want: "02 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 00 00 00 00 00 00 00 00 00 ff ff 02 03 04 05 04 d2", - }, - { - name: "pongv6", - m: &Pong{ - TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, - Src: mustIPPort("[fed0::12]:6666"), - }, - want: "02 00 01 02 03 04 05 06 07 08 09 0a 0b 0c fe d0 00 00 00 00 00 00 00 00 00 00 00 00 00 12 1a 0a", - }, - { - name: "call_me_maybe", - m: &CallMeMaybe{}, - want: "03 00", - }, - { - name: "call_me_maybe_endpoints", - m: &CallMeMaybe{ - MyNumber: []netip.AddrPort{ - netip.MustParseAddrPort("1.2.3.4:567"), - netip.MustParseAddrPort("[2001::3456]:789"), - }, - }, - want: "03 00 00 00 00 00 00 00 00 00 00 00 ff ff 01 02 03 04 02 37 20 01 00 00 00 00 00 00 00 00 00 00 00 00 34 56 03 15", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - foo := []byte("foo") - got := string(tt.m.AppendMarshal(foo)) - got, ok := strings.CutPrefix(got, "foo") - if !ok { - t.Fatalf("didn't start with foo: got %q", got) - } - - gotHex := fmt.Sprintf("% x", got) - if gotHex != tt.want { - t.Fatalf("wrong marshal\n got: %s\nwant: %s\n", gotHex, tt.want) - } - - back, err := Parse([]byte(got)) - if err != nil { - t.Fatalf("parse back: %v", err) - } - if !reflect.DeepEqual(back, tt.m) { - t.Errorf("message in %+v doesn't match Parse back result %+v", tt.m, back) - } - }) - } -} - -func mustIPPort(s string) netip.AddrPort { - ipp, err := netip.ParseAddrPort(s) - if err != nil { - panic(err) - } - return ipp -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package disco + +import ( + "fmt" + "net/netip" + "reflect" + "strings" + "testing" + + "go4.org/mem" + "tailscale.com/types/key" +) + +func TestMarshalAndParse(t *testing.T) { + tests := []struct { + name string + want string + m Message + }{ + { + name: "ping", + m: &Ping{ + TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + }, + want: "01 00 01 02 03 04 05 06 07 08 09 0a 0b 0c", + }, + { + name: "ping_with_nodekey_src", + m: &Ping{ + TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + NodeKey: key.NodePublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 30: 30, 31: 31})), + }, + want: "01 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 01 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f", + }, + { + name: "ping_with_padding", + m: &Ping{ + TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + Padding: 3, + }, + want: "01 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 00 00", + }, + { + name: "ping_with_padding_and_nodekey_src", + m: &Ping{ + TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + NodeKey: key.NodePublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 30: 30, 31: 31})), + Padding: 3, + }, + want: "01 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 01 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f 00 00 00", + }, + { + name: "pong", + m: &Pong{ + TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + Src: mustIPPort("2.3.4.5:1234"), + }, + want: "02 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 00 00 00 00 00 00 00 00 00 00 ff ff 02 03 04 05 04 d2", + }, + { + name: "pongv6", + m: &Pong{ + TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + Src: mustIPPort("[fed0::12]:6666"), + }, + want: "02 00 01 02 03 04 05 06 07 08 09 0a 0b 0c fe d0 00 00 00 00 00 00 00 00 00 00 00 00 00 12 1a 0a", + }, + { + name: "call_me_maybe", + m: &CallMeMaybe{}, + want: "03 00", + }, + { + name: "call_me_maybe_endpoints", + m: &CallMeMaybe{ + MyNumber: []netip.AddrPort{ + netip.MustParseAddrPort("1.2.3.4:567"), + netip.MustParseAddrPort("[2001::3456]:789"), + }, + }, + want: "03 00 00 00 00 00 00 00 00 00 00 00 ff ff 01 02 03 04 02 37 20 01 00 00 00 00 00 00 00 00 00 00 00 00 34 56 03 15", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + foo := []byte("foo") + got := string(tt.m.AppendMarshal(foo)) + got, ok := strings.CutPrefix(got, "foo") + if !ok { + t.Fatalf("didn't start with foo: got %q", got) + } + + gotHex := fmt.Sprintf("% x", got) + if gotHex != tt.want { + t.Fatalf("wrong marshal\n got: %s\nwant: %s\n", gotHex, tt.want) + } + + back, err := Parse([]byte(got)) + if err != nil { + t.Fatalf("parse back: %v", err) + } + if !reflect.DeepEqual(back, tt.m) { + t.Errorf("message in %+v doesn't match Parse back result %+v", tt.m, back) + } + }) + } +} + +func mustIPPort(s string) netip.AddrPort { + ipp, err := netip.ParseAddrPort(s) + if err != nil { + panic(err) + } + return ipp +} diff --git a/disco/pcap.go b/disco/pcap.go index 5d60ceb28..710354248 100644 --- a/disco/pcap.go +++ b/disco/pcap.go @@ -1,40 +1,40 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package disco - -import ( - "bytes" - "encoding/binary" - "net/netip" - - "tailscale.com/tailcfg" - "tailscale.com/types/key" -) - -// ToPCAPFrame marshals the bytes for a pcap record that describe a disco frame. -// -// Warning: Alloc garbage. Acceptable while capturing. -func ToPCAPFrame(src netip.AddrPort, derpNodeSrc key.NodePublic, payload []byte) []byte { - var ( - b bytes.Buffer - flag uint8 - ) - b.Grow(128) // Most disco frames will probably be smaller than this. - - if src.Addr() == tailcfg.DerpMagicIPAddr { - flag |= 0x01 - } - b.WriteByte(flag) // 1b: flag - - derpSrc := derpNodeSrc.Raw32() - b.Write(derpSrc[:]) // 32b: derp public key - binary.Write(&b, binary.LittleEndian, uint16(src.Port())) // 2b: port - addr, _ := src.Addr().MarshalBinary() - binary.Write(&b, binary.LittleEndian, uint16(len(addr))) // 2b: len(addr) - b.Write(addr) // Xb: addr - binary.Write(&b, binary.LittleEndian, uint16(len(payload))) // 2b: len(payload) - b.Write(payload) // Xb: payload - - return b.Bytes() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package disco + +import ( + "bytes" + "encoding/binary" + "net/netip" + + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +// ToPCAPFrame marshals the bytes for a pcap record that describe a disco frame. +// +// Warning: Alloc garbage. Acceptable while capturing. +func ToPCAPFrame(src netip.AddrPort, derpNodeSrc key.NodePublic, payload []byte) []byte { + var ( + b bytes.Buffer + flag uint8 + ) + b.Grow(128) // Most disco frames will probably be smaller than this. + + if src.Addr() == tailcfg.DerpMagicIPAddr { + flag |= 0x01 + } + b.WriteByte(flag) // 1b: flag + + derpSrc := derpNodeSrc.Raw32() + b.Write(derpSrc[:]) // 32b: derp public key + binary.Write(&b, binary.LittleEndian, uint16(src.Port())) // 2b: port + addr, _ := src.Addr().MarshalBinary() + binary.Write(&b, binary.LittleEndian, uint16(len(addr))) // 2b: len(addr) + b.Write(addr) // Xb: addr + binary.Write(&b, binary.LittleEndian, uint16(len(payload))) // 2b: len(payload) + b.Write(payload) // Xb: payload + + return b.Bytes() +} diff --git a/docs/bird/sample_bird.conf b/docs/bird/sample_bird.conf index 87222c59a..ed38e66c5 100644 --- a/docs/bird/sample_bird.conf +++ b/docs/bird/sample_bird.conf @@ -1,16 +1,16 @@ -log syslog all; - -protocol device { - scan time 10; -} - -protocol bgp { - local as 64001; - neighbor 10.40.2.101 as 64002; - ipv4 { - import none; - export all; - }; -} - -include "tailscale_bird.conf"; +log syslog all; + +protocol device { + scan time 10; +} + +protocol bgp { + local as 64001; + neighbor 10.40.2.101 as 64002; + ipv4 { + import none; + export all; + }; +} + +include "tailscale_bird.conf"; diff --git a/docs/bird/tailscale_bird.conf b/docs/bird/tailscale_bird.conf index a5f430747..8211a50a3 100644 --- a/docs/bird/tailscale_bird.conf +++ b/docs/bird/tailscale_bird.conf @@ -1,4 +1,4 @@ -protocol static tailscale { - ipv4; - route 100.64.0.0/10 via "tailscale0"; -} +protocol static tailscale { + ipv4; + route 100.64.0.0/10 via "tailscale0"; +} diff --git a/docs/k8s/Makefile b/docs/k8s/Makefile index 107c1c136..55804c857 100644 --- a/docs/k8s/Makefile +++ b/docs/k8s/Makefile @@ -1,25 +1,25 @@ -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause - -TS_ROUTES ?= "" -SA_NAME ?= tailscale -TS_KUBE_SECRET ?= tailscale - -rbac: - @sed -e "s;{{TS_KUBE_SECRET}};$(TS_KUBE_SECRET);g" role.yaml - @echo "---" - @sed -e "s;{{SA_NAME}};$(SA_NAME);g" rolebinding.yaml - @echo "---" - @sed -e "s;{{SA_NAME}};$(SA_NAME);g" sa.yaml - -sidecar: - @sed -e "s;{{TS_KUBE_SECRET}};$(TS_KUBE_SECRET);g" sidecar.yaml | sed -e "s;{{SA_NAME}};$(SA_NAME);g" - -userspace-sidecar: - @sed -e "s;{{TS_KUBE_SECRET}};$(TS_KUBE_SECRET);g" userspace-sidecar.yaml | sed -e "s;{{SA_NAME}};$(SA_NAME);g" - -proxy: - @sed -e "s;{{TS_KUBE_SECRET}};$(TS_KUBE_SECRET);g" proxy.yaml | sed -e "s;{{SA_NAME}};$(SA_NAME);g" | sed -e "s;{{TS_DEST_IP}};$(TS_DEST_IP);g" - -subnet-router: - @sed -e "s;{{TS_KUBE_SECRET}};$(TS_KUBE_SECRET);g" subnet.yaml | sed -e "s;{{SA_NAME}};$(SA_NAME);g" | sed -e "s;{{TS_ROUTES}};$(TS_ROUTES);g" +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause + +TS_ROUTES ?= "" +SA_NAME ?= tailscale +TS_KUBE_SECRET ?= tailscale + +rbac: + @sed -e "s;{{TS_KUBE_SECRET}};$(TS_KUBE_SECRET);g" role.yaml + @echo "---" + @sed -e "s;{{SA_NAME}};$(SA_NAME);g" rolebinding.yaml + @echo "---" + @sed -e "s;{{SA_NAME}};$(SA_NAME);g" sa.yaml + +sidecar: + @sed -e "s;{{TS_KUBE_SECRET}};$(TS_KUBE_SECRET);g" sidecar.yaml | sed -e "s;{{SA_NAME}};$(SA_NAME);g" + +userspace-sidecar: + @sed -e "s;{{TS_KUBE_SECRET}};$(TS_KUBE_SECRET);g" userspace-sidecar.yaml | sed -e "s;{{SA_NAME}};$(SA_NAME);g" + +proxy: + @sed -e "s;{{TS_KUBE_SECRET}};$(TS_KUBE_SECRET);g" proxy.yaml | sed -e "s;{{SA_NAME}};$(SA_NAME);g" | sed -e "s;{{TS_DEST_IP}};$(TS_DEST_IP);g" + +subnet-router: + @sed -e "s;{{TS_KUBE_SECRET}};$(TS_KUBE_SECRET);g" subnet.yaml | sed -e "s;{{SA_NAME}};$(SA_NAME);g" | sed -e "s;{{TS_ROUTES}};$(TS_ROUTES);g" diff --git a/docs/k8s/rolebinding.yaml b/docs/k8s/rolebinding.yaml index b32e66b98..3b18ba8d3 100644 --- a/docs/k8s/rolebinding.yaml +++ b/docs/k8s/rolebinding.yaml @@ -1,13 +1,13 @@ -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause -apiVersion: rbac.authorization.k8s.io/v1 -kind: RoleBinding -metadata: - name: tailscale -subjects: -- kind: ServiceAccount - name: "{{SA_NAME}}" -roleRef: - kind: Role - name: tailscale - apiGroup: rbac.authorization.k8s.io +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause +apiVersion: rbac.authorization.k8s.io/v1 +kind: RoleBinding +metadata: + name: tailscale +subjects: +- kind: ServiceAccount + name: "{{SA_NAME}}" +roleRef: + kind: Role + name: tailscale + apiGroup: rbac.authorization.k8s.io diff --git a/docs/k8s/sa.yaml b/docs/k8s/sa.yaml index 85b56bd24..edd3944ba 100644 --- a/docs/k8s/sa.yaml +++ b/docs/k8s/sa.yaml @@ -1,6 +1,6 @@ -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause -apiVersion: v1 -kind: ServiceAccount -metadata: - name: {{SA_NAME}} +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause +apiVersion: v1 +kind: ServiceAccount +metadata: + name: {{SA_NAME}} diff --git a/docs/sysv/tailscale.init b/docs/sysv/tailscale.init index fc22088b1..ca21033df 100755 --- a/docs/sysv/tailscale.init +++ b/docs/sysv/tailscale.init @@ -1,63 +1,63 @@ -#!/bin/sh -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause - -### BEGIN INIT INFO -# Provides: tailscaled -# Required-Start: -# Required-Stop: -# Default-Start: -# Default-Stop: -# Short-Description: Tailscale Mesh Wireguard VPN -### END INIT INFO - -set -e - -# /etc/init.d/tailscale: start and stop the Tailscale VPN service - -test -x /usr/sbin/tailscaled || exit 0 - -umask 022 - -. /lib/lsb/init-functions - -# Are we running from init? -run_by_init() { - ([ "$previous" ] && [ "$runlevel" ]) || [ "$runlevel" = S ] -} - -export PATH="${PATH:+$PATH:}/usr/sbin:/sbin" - -case "$1" in - start) - log_daemon_msg "Starting Tailscale VPN" "tailscaled" || true - if start-stop-daemon --start --oknodo --name tailscaled -m --pidfile /run/tailscaled.pid --background \ - --exec /usr/sbin/tailscaled -- \ - --state=/var/lib/tailscale/tailscaled.state \ - --socket=/run/tailscale/tailscaled.sock \ - --port 41641; - then - log_end_msg 0 || true - else - log_end_msg 1 || true - fi - ;; - stop) - log_daemon_msg "Stopping Tailscale VPN" "tailscaled" || true - if start-stop-daemon --stop --remove-pidfile --pidfile /run/tailscaled.pid --exec /usr/sbin/tailscaled; then - log_end_msg 0 || true - else - log_end_msg 1 || true - fi - ;; - - status) - status_of_proc -p /run/tailscaled.pid /usr/sbin/tailscaled tailscaled && exit 0 || exit $? - ;; - - *) - log_action_msg "Usage: /etc/init.d/tailscaled {start|stop|status}" || true - exit 1 -esac - -exit 0 +#!/bin/sh +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause + +### BEGIN INIT INFO +# Provides: tailscaled +# Required-Start: +# Required-Stop: +# Default-Start: +# Default-Stop: +# Short-Description: Tailscale Mesh Wireguard VPN +### END INIT INFO + +set -e + +# /etc/init.d/tailscale: start and stop the Tailscale VPN service + +test -x /usr/sbin/tailscaled || exit 0 + +umask 022 + +. /lib/lsb/init-functions + +# Are we running from init? +run_by_init() { + ([ "$previous" ] && [ "$runlevel" ]) || [ "$runlevel" = S ] +} + +export PATH="${PATH:+$PATH:}/usr/sbin:/sbin" + +case "$1" in + start) + log_daemon_msg "Starting Tailscale VPN" "tailscaled" || true + if start-stop-daemon --start --oknodo --name tailscaled -m --pidfile /run/tailscaled.pid --background \ + --exec /usr/sbin/tailscaled -- \ + --state=/var/lib/tailscale/tailscaled.state \ + --socket=/run/tailscale/tailscaled.sock \ + --port 41641; + then + log_end_msg 0 || true + else + log_end_msg 1 || true + fi + ;; + stop) + log_daemon_msg "Stopping Tailscale VPN" "tailscaled" || true + if start-stop-daemon --stop --remove-pidfile --pidfile /run/tailscaled.pid --exec /usr/sbin/tailscaled; then + log_end_msg 0 || true + else + log_end_msg 1 || true + fi + ;; + + status) + status_of_proc -p /run/tailscaled.pid /usr/sbin/tailscaled tailscaled && exit 0 || exit $? + ;; + + *) + log_action_msg "Usage: /etc/init.d/tailscaled {start|stop|status}" || true + exit 1 +esac + +exit 0 diff --git a/doctor/doctor.go b/doctor/doctor.go index 96af39f5f..7c3047e12 100644 --- a/doctor/doctor.go +++ b/doctor/doctor.go @@ -1,79 +1,79 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package doctor contains more in-depth healthchecks that can be run to aid in -// diagnosing Tailscale issues. -package doctor - -import ( - "context" - "sync" - - "tailscale.com/types/logger" -) - -// Check is the interface defining a singular check. -// -// A check should log information that it gathers using the provided log -// function, and should attempt to make as much progress as possible in error -// conditions. -type Check interface { - // Name should return a name describing this check, in lower-kebab-case - // (i.e. "my-check", not "MyCheck" or "my_check"). - Name() string - // Run executes the check, logging diagnostic information to the - // provided logger function. - Run(context.Context, logger.Logf) error -} - -// RunChecks runs a list of checks in parallel, and logs any returned errors -// after all checks have returned. -func RunChecks(ctx context.Context, log logger.Logf, checks ...Check) { - if len(checks) == 0 { - return - } - - type namedErr struct { - name string - err error - } - errs := make(chan namedErr, len(checks)) - - var wg sync.WaitGroup - wg.Add(len(checks)) - for _, check := range checks { - go func(c Check) { - defer wg.Done() - - plog := logger.WithPrefix(log, c.Name()+": ") - errs <- namedErr{ - name: c.Name(), - err: c.Run(ctx, plog), - } - }(check) - } - - wg.Wait() - close(errs) - - for n := range errs { - if n.err == nil { - continue - } - - log("check %s: %v", n.name, n.err) - } -} - -// CheckFunc creates a Check from a name and a function. -func CheckFunc(name string, run func(context.Context, logger.Logf) error) Check { - return checkFunc{name, run} -} - -type checkFunc struct { - name string - run func(context.Context, logger.Logf) error -} - -func (c checkFunc) Name() string { return c.name } -func (c checkFunc) Run(ctx context.Context, log logger.Logf) error { return c.run(ctx, log) } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package doctor contains more in-depth healthchecks that can be run to aid in +// diagnosing Tailscale issues. +package doctor + +import ( + "context" + "sync" + + "tailscale.com/types/logger" +) + +// Check is the interface defining a singular check. +// +// A check should log information that it gathers using the provided log +// function, and should attempt to make as much progress as possible in error +// conditions. +type Check interface { + // Name should return a name describing this check, in lower-kebab-case + // (i.e. "my-check", not "MyCheck" or "my_check"). + Name() string + // Run executes the check, logging diagnostic information to the + // provided logger function. + Run(context.Context, logger.Logf) error +} + +// RunChecks runs a list of checks in parallel, and logs any returned errors +// after all checks have returned. +func RunChecks(ctx context.Context, log logger.Logf, checks ...Check) { + if len(checks) == 0 { + return + } + + type namedErr struct { + name string + err error + } + errs := make(chan namedErr, len(checks)) + + var wg sync.WaitGroup + wg.Add(len(checks)) + for _, check := range checks { + go func(c Check) { + defer wg.Done() + + plog := logger.WithPrefix(log, c.Name()+": ") + errs <- namedErr{ + name: c.Name(), + err: c.Run(ctx, plog), + } + }(check) + } + + wg.Wait() + close(errs) + + for n := range errs { + if n.err == nil { + continue + } + + log("check %s: %v", n.name, n.err) + } +} + +// CheckFunc creates a Check from a name and a function. +func CheckFunc(name string, run func(context.Context, logger.Logf) error) Check { + return checkFunc{name, run} +} + +type checkFunc struct { + name string + run func(context.Context, logger.Logf) error +} + +func (c checkFunc) Name() string { return c.name } +func (c checkFunc) Run(ctx context.Context, log logger.Logf) error { return c.run(ctx, log) } diff --git a/doctor/doctor_test.go b/doctor/doctor_test.go index dab7afa38..87250f10e 100644 --- a/doctor/doctor_test.go +++ b/doctor/doctor_test.go @@ -1,49 +1,49 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package doctor - -import ( - "context" - "fmt" - "sync" - "testing" - - qt "github.com/frankban/quicktest" - "tailscale.com/types/logger" -) - -func TestRunChecks(t *testing.T) { - c := qt.New(t) - var ( - mu sync.Mutex - lines []string - ) - logf := func(format string, args ...any) { - mu.Lock() - defer mu.Unlock() - lines = append(lines, fmt.Sprintf(format, args...)) - } - - ctx := context.Background() - RunChecks(ctx, logf, - testCheck1{}, - CheckFunc("testcheck2", func(_ context.Context, log logger.Logf) error { - log("check 2") - return nil - }), - ) - - mu.Lock() - defer mu.Unlock() - c.Assert(lines, qt.Contains, "testcheck1: check 1") - c.Assert(lines, qt.Contains, "testcheck2: check 2") -} - -type testCheck1 struct{} - -func (t testCheck1) Name() string { return "testcheck1" } -func (t testCheck1) Run(_ context.Context, log logger.Logf) error { - log("check 1") - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package doctor + +import ( + "context" + "fmt" + "sync" + "testing" + + qt "github.com/frankban/quicktest" + "tailscale.com/types/logger" +) + +func TestRunChecks(t *testing.T) { + c := qt.New(t) + var ( + mu sync.Mutex + lines []string + ) + logf := func(format string, args ...any) { + mu.Lock() + defer mu.Unlock() + lines = append(lines, fmt.Sprintf(format, args...)) + } + + ctx := context.Background() + RunChecks(ctx, logf, + testCheck1{}, + CheckFunc("testcheck2", func(_ context.Context, log logger.Logf) error { + log("check 2") + return nil + }), + ) + + mu.Lock() + defer mu.Unlock() + c.Assert(lines, qt.Contains, "testcheck1: check 1") + c.Assert(lines, qt.Contains, "testcheck2: check 2") +} + +type testCheck1 struct{} + +func (t testCheck1) Name() string { return "testcheck1" } +func (t testCheck1) Run(_ context.Context, log logger.Logf) error { + log("check 1") + return nil +} diff --git a/doctor/permissions/permissions_bsd.go b/doctor/permissions/permissions_bsd.go index 4031af722..8b034cfff 100644 --- a/doctor/permissions/permissions_bsd.go +++ b/doctor/permissions/permissions_bsd.go @@ -1,23 +1,23 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build darwin || freebsd || openbsd - -package permissions - -import ( - "golang.org/x/sys/unix" - "tailscale.com/types/logger" -) - -func permissionsImpl(logf logger.Logf) error { - groups, _ := unix.Getgroups() - logf("uid=%s euid=%s gid=%s egid=%s groups=%s", - formatUserID(unix.Getuid()), - formatUserID(unix.Geteuid()), - formatGroupID(unix.Getgid()), - formatGroupID(unix.Getegid()), - formatGroups(groups), - ) - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build darwin || freebsd || openbsd + +package permissions + +import ( + "golang.org/x/sys/unix" + "tailscale.com/types/logger" +) + +func permissionsImpl(logf logger.Logf) error { + groups, _ := unix.Getgroups() + logf("uid=%s euid=%s gid=%s egid=%s groups=%s", + formatUserID(unix.Getuid()), + formatUserID(unix.Geteuid()), + formatGroupID(unix.Getgid()), + formatGroupID(unix.Getegid()), + formatGroups(groups), + ) + return nil +} diff --git a/doctor/permissions/permissions_linux.go b/doctor/permissions/permissions_linux.go index ef0a97056..12bb393d5 100644 --- a/doctor/permissions/permissions_linux.go +++ b/doctor/permissions/permissions_linux.go @@ -1,62 +1,62 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux - -package permissions - -import ( - "fmt" - "strings" - "unsafe" - - "golang.org/x/sys/unix" - "tailscale.com/types/logger" -) - -func permissionsImpl(logf logger.Logf) error { - // NOTE: getresuid and getresgid never fail unless passed an - // invalid address. - var ruid, euid, suid uint64 - unix.Syscall(unix.SYS_GETRESUID, - uintptr(unsafe.Pointer(&ruid)), - uintptr(unsafe.Pointer(&euid)), - uintptr(unsafe.Pointer(&suid)), - ) - - var rgid, egid, sgid uint64 - unix.Syscall(unix.SYS_GETRESGID, - uintptr(unsafe.Pointer(&rgid)), - uintptr(unsafe.Pointer(&egid)), - uintptr(unsafe.Pointer(&sgid)), - ) - - groups, _ := unix.Getgroups() - - var buf strings.Builder - fmt.Fprintf(&buf, "ruid=%s euid=%s suid=%s rgid=%s egid=%s sgid=%s groups=%s", - formatUserID(ruid), formatUserID(euid), formatUserID(suid), - formatGroupID(rgid), formatGroupID(egid), formatGroupID(sgid), - formatGroups(groups), - ) - - // Get process capabilities - var ( - capHeader = unix.CapUserHeader{ - Version: unix.LINUX_CAPABILITY_VERSION_3, - Pid: 0, // 0 means 'ourselves' - } - capData unix.CapUserData - ) - - if err := unix.Capget(&capHeader, &capData); err != nil { - fmt.Fprintf(&buf, " caperr=%v", err) - } else { - fmt.Fprintf(&buf, " cap_effective=%08x cap_permitted=%08x cap_inheritable=%08x", - capData.Effective, capData.Permitted, capData.Inheritable, - ) - } - - logf("%s", buf.String()) - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package permissions + +import ( + "fmt" + "strings" + "unsafe" + + "golang.org/x/sys/unix" + "tailscale.com/types/logger" +) + +func permissionsImpl(logf logger.Logf) error { + // NOTE: getresuid and getresgid never fail unless passed an + // invalid address. + var ruid, euid, suid uint64 + unix.Syscall(unix.SYS_GETRESUID, + uintptr(unsafe.Pointer(&ruid)), + uintptr(unsafe.Pointer(&euid)), + uintptr(unsafe.Pointer(&suid)), + ) + + var rgid, egid, sgid uint64 + unix.Syscall(unix.SYS_GETRESGID, + uintptr(unsafe.Pointer(&rgid)), + uintptr(unsafe.Pointer(&egid)), + uintptr(unsafe.Pointer(&sgid)), + ) + + groups, _ := unix.Getgroups() + + var buf strings.Builder + fmt.Fprintf(&buf, "ruid=%s euid=%s suid=%s rgid=%s egid=%s sgid=%s groups=%s", + formatUserID(ruid), formatUserID(euid), formatUserID(suid), + formatGroupID(rgid), formatGroupID(egid), formatGroupID(sgid), + formatGroups(groups), + ) + + // Get process capabilities + var ( + capHeader = unix.CapUserHeader{ + Version: unix.LINUX_CAPABILITY_VERSION_3, + Pid: 0, // 0 means 'ourselves' + } + capData unix.CapUserData + ) + + if err := unix.Capget(&capHeader, &capData); err != nil { + fmt.Fprintf(&buf, " caperr=%v", err) + } else { + fmt.Fprintf(&buf, " cap_effective=%08x cap_permitted=%08x cap_inheritable=%08x", + capData.Effective, capData.Permitted, capData.Inheritable, + ) + } + + logf("%s", buf.String()) + return nil +} diff --git a/doctor/permissions/permissions_other.go b/doctor/permissions/permissions_other.go index 5e310b98e..7e6912b49 100644 --- a/doctor/permissions/permissions_other.go +++ b/doctor/permissions/permissions_other.go @@ -1,17 +1,17 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !(linux || darwin || freebsd || openbsd) - -package permissions - -import ( - "runtime" - - "tailscale.com/types/logger" -) - -func permissionsImpl(logf logger.Logf) error { - logf("unsupported on %s/%s", runtime.GOOS, runtime.GOARCH) - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !(linux || darwin || freebsd || openbsd) + +package permissions + +import ( + "runtime" + + "tailscale.com/types/logger" +) + +func permissionsImpl(logf logger.Logf) error { + logf("unsupported on %s/%s", runtime.GOOS, runtime.GOARCH) + return nil +} diff --git a/doctor/permissions/permissions_test.go b/doctor/permissions/permissions_test.go index 9b71c3be1..941d406ef 100644 --- a/doctor/permissions/permissions_test.go +++ b/doctor/permissions/permissions_test.go @@ -1,12 +1,12 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package permissions - -import "testing" - -func TestPermissionsImpl(t *testing.T) { - if err := permissionsImpl(t.Logf); err != nil { - t.Error(err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package permissions + +import "testing" + +func TestPermissionsImpl(t *testing.T) { + if err := permissionsImpl(t.Logf); err != nil { + t.Error(err) + } +} diff --git a/doctor/routetable/routetable.go b/doctor/routetable/routetable.go index 1ebf294ce..76e4ef949 100644 --- a/doctor/routetable/routetable.go +++ b/doctor/routetable/routetable.go @@ -1,34 +1,34 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package routetable provides a doctor.Check that dumps the current system's -// route table to the log. -package routetable - -import ( - "context" - - "tailscale.com/net/routetable" - "tailscale.com/types/logger" -) - -// MaxRoutes is the maximum number of routes that will be displayed. -const MaxRoutes = 1000 - -// Check implements the doctor.Check interface. -type Check struct{} - -func (Check) Name() string { - return "routetable" -} - -func (Check) Run(_ context.Context, logf logger.Logf) error { - rs, err := routetable.Get(MaxRoutes) - if err != nil { - return err - } - for _, r := range rs { - logf("%s", r) - } - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package routetable provides a doctor.Check that dumps the current system's +// route table to the log. +package routetable + +import ( + "context" + + "tailscale.com/net/routetable" + "tailscale.com/types/logger" +) + +// MaxRoutes is the maximum number of routes that will be displayed. +const MaxRoutes = 1000 + +// Check implements the doctor.Check interface. +type Check struct{} + +func (Check) Name() string { + return "routetable" +} + +func (Check) Run(_ context.Context, logf logger.Logf) error { + rs, err := routetable.Get(MaxRoutes) + if err != nil { + return err + } + for _, r := range rs { + logf("%s", r) + } + return nil +} diff --git a/envknob/envknob_nottest.go b/envknob/envknob_nottest.go index b21266f13..0dd900cc8 100644 --- a/envknob/envknob_nottest.go +++ b/envknob/envknob_nottest.go @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ts_not_in_tests - -package envknob - -import "runtime" - -func GOOS() string { - // When the "ts_not_in_tests" build tag is used, we define this func to just - // return a simple constant so callers optimize just as if the knob were not - // present. We can then build production/optimized builds with the - // "ts_not_in_tests" build tag. - return runtime.GOOS -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ts_not_in_tests + +package envknob + +import "runtime" + +func GOOS() string { + // When the "ts_not_in_tests" build tag is used, we define this func to just + // return a simple constant so callers optimize just as if the knob were not + // present. We can then build production/optimized builds with the + // "ts_not_in_tests" build tag. + return runtime.GOOS +} diff --git a/envknob/envknob_testable.go b/envknob/envknob_testable.go index 53687d732..e7f038336 100644 --- a/envknob/envknob_testable.go +++ b/envknob/envknob_testable.go @@ -1,23 +1,23 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !ts_not_in_tests - -package envknob - -import "runtime" - -// GOOS reports the effective runtime.GOOS to run as. -// -// In practice this returns just runtime.GOOS, unless overridden by -// test TS_DEBUG_FAKE_GOOS. -// -// This allows changing OS-specific stuff like the IPN server behavior -// for tests so we can e.g. test Windows-specific behaviors on Linux. -// This isn't universally used. -func GOOS() string { - if v := String("TS_DEBUG_FAKE_GOOS"); v != "" { - return v - } - return runtime.GOOS -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_not_in_tests + +package envknob + +import "runtime" + +// GOOS reports the effective runtime.GOOS to run as. +// +// In practice this returns just runtime.GOOS, unless overridden by +// test TS_DEBUG_FAKE_GOOS. +// +// This allows changing OS-specific stuff like the IPN server behavior +// for tests so we can e.g. test Windows-specific behaviors on Linux. +// This isn't universally used. +func GOOS() string { + if v := String("TS_DEBUG_FAKE_GOOS"); v != "" { + return v + } + return runtime.GOOS +} diff --git a/envknob/logknob/logknob.go b/envknob/logknob/logknob.go index a7b0a05e8..350384b86 100644 --- a/envknob/logknob/logknob.go +++ b/envknob/logknob/logknob.go @@ -1,85 +1,85 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package logknob provides a helpful wrapper that allows enabling logging -// based on either an envknob or other methods of enablement. -package logknob - -import ( - "sync/atomic" - - "tailscale.com/envknob" - "tailscale.com/tailcfg" - "tailscale.com/types/logger" - "tailscale.com/types/views" -) - -// TODO(andrew-d): should we have a package-global registry of logknobs? It -// would allow us to update from a netmap in a central location, which might be -// reason enough to do it... - -// LogKnob allows configuring verbose logging, with multiple ways to enable. It -// supports enabling logging via envknob, via atomic boolean (for use in e.g. -// c2n log level changes), and via capabilities from a NetMap (so users can -// enable logging via the ACL JSON). -type LogKnob struct { - capName tailcfg.NodeCapability - cap atomic.Bool - env func() bool - manual atomic.Bool -} - -// NewLogKnob creates a new LogKnob, with the provided environment variable -// name and/or NetMap capability. -func NewLogKnob(env string, cap tailcfg.NodeCapability) *LogKnob { - if env == "" && cap == "" { - panic("must provide either an environment variable or capability") - } - - lk := &LogKnob{ - capName: cap, - } - if env != "" { - lk.env = envknob.RegisterBool(env) - } else { - lk.env = func() bool { return false } - } - return lk -} - -// Set will cause logs to be printed when called with Set(true). When called -// with Set(false), logs will not be printed due to an earlier call of -// Set(true), but may be printed due to either the envknob and/or capability of -// this LogKnob. -func (lk *LogKnob) Set(v bool) { - lk.manual.Store(v) -} - -// NetMap is an interface for the parts of netmap.NetworkMap that we care -// about; we use this rather than a concrete type to avoid a circular -// dependency. -type NetMap interface { - SelfCapabilities() views.Slice[tailcfg.NodeCapability] -} - -// UpdateFromNetMap will enable logging if the SelfNode in the provided NetMap -// contains the capability provided for this LogKnob. -func (lk *LogKnob) UpdateFromNetMap(nm NetMap) { - if lk.capName == "" { - return - } - - lk.cap.Store(views.SliceContains(nm.SelfCapabilities(), lk.capName)) -} - -// Do will call log with the provided format and arguments if any of the -// configured methods for enabling logging are true. -func (lk *LogKnob) Do(log logger.Logf, format string, args ...any) { - if lk.shouldLog() { - log(format, args...) - } -} - -func (lk *LogKnob) shouldLog() bool { - return lk.manual.Load() || lk.env() || lk.cap.Load() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package logknob provides a helpful wrapper that allows enabling logging +// based on either an envknob or other methods of enablement. +package logknob + +import ( + "sync/atomic" + + "tailscale.com/envknob" + "tailscale.com/tailcfg" + "tailscale.com/types/logger" + "tailscale.com/types/views" +) + +// TODO(andrew-d): should we have a package-global registry of logknobs? It +// would allow us to update from a netmap in a central location, which might be +// reason enough to do it... + +// LogKnob allows configuring verbose logging, with multiple ways to enable. It +// supports enabling logging via envknob, via atomic boolean (for use in e.g. +// c2n log level changes), and via capabilities from a NetMap (so users can +// enable logging via the ACL JSON). +type LogKnob struct { + capName tailcfg.NodeCapability + cap atomic.Bool + env func() bool + manual atomic.Bool +} + +// NewLogKnob creates a new LogKnob, with the provided environment variable +// name and/or NetMap capability. +func NewLogKnob(env string, cap tailcfg.NodeCapability) *LogKnob { + if env == "" && cap == "" { + panic("must provide either an environment variable or capability") + } + + lk := &LogKnob{ + capName: cap, + } + if env != "" { + lk.env = envknob.RegisterBool(env) + } else { + lk.env = func() bool { return false } + } + return lk +} + +// Set will cause logs to be printed when called with Set(true). When called +// with Set(false), logs will not be printed due to an earlier call of +// Set(true), but may be printed due to either the envknob and/or capability of +// this LogKnob. +func (lk *LogKnob) Set(v bool) { + lk.manual.Store(v) +} + +// NetMap is an interface for the parts of netmap.NetworkMap that we care +// about; we use this rather than a concrete type to avoid a circular +// dependency. +type NetMap interface { + SelfCapabilities() views.Slice[tailcfg.NodeCapability] +} + +// UpdateFromNetMap will enable logging if the SelfNode in the provided NetMap +// contains the capability provided for this LogKnob. +func (lk *LogKnob) UpdateFromNetMap(nm NetMap) { + if lk.capName == "" { + return + } + + lk.cap.Store(views.SliceContains(nm.SelfCapabilities(), lk.capName)) +} + +// Do will call log with the provided format and arguments if any of the +// configured methods for enabling logging are true. +func (lk *LogKnob) Do(log logger.Logf, format string, args ...any) { + if lk.shouldLog() { + log(format, args...) + } +} + +func (lk *LogKnob) shouldLog() bool { + return lk.manual.Load() || lk.env() || lk.cap.Load() +} diff --git a/envknob/logknob/logknob_test.go b/envknob/logknob/logknob_test.go index c9eed5612..b2a376a25 100644 --- a/envknob/logknob/logknob_test.go +++ b/envknob/logknob/logknob_test.go @@ -1,102 +1,102 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package logknob - -import ( - "bytes" - "fmt" - "testing" - - "tailscale.com/envknob" - "tailscale.com/tailcfg" - "tailscale.com/types/netmap" -) - -var testKnob = NewLogKnob( - "TS_TEST_LOGKNOB", - "https://tailscale.com/cap/testing", -) - -// Static type assertion for our interface type. -var _ NetMap = &netmap.NetworkMap{} - -func TestLogKnob(t *testing.T) { - t.Run("Default", func(t *testing.T) { - if testKnob.shouldLog() { - t.Errorf("expected default shouldLog()=false") - } - assertNoLogs(t) - }) - t.Run("Manual", func(t *testing.T) { - t.Cleanup(func() { testKnob.Set(false) }) - - assertNoLogs(t) - testKnob.Set(true) - if !testKnob.shouldLog() { - t.Errorf("expected shouldLog()=true") - } - assertLogs(t) - }) - t.Run("Env", func(t *testing.T) { - t.Cleanup(func() { - envknob.Setenv("TS_TEST_LOGKNOB", "") - }) - - assertNoLogs(t) - if testKnob.shouldLog() { - t.Errorf("expected default shouldLog()=false") - } - - envknob.Setenv("TS_TEST_LOGKNOB", "true") - if !testKnob.shouldLog() { - t.Errorf("expected shouldLog()=true") - } - assertLogs(t) - }) - t.Run("NetMap", func(t *testing.T) { - t.Cleanup(func() { testKnob.cap.Store(false) }) - - assertNoLogs(t) - if testKnob.shouldLog() { - t.Errorf("expected default shouldLog()=false") - } - - testKnob.UpdateFromNetMap(&netmap.NetworkMap{ - SelfNode: (&tailcfg.Node{ - Capabilities: []tailcfg.NodeCapability{ - "https://tailscale.com/cap/testing", - }, - }).View(), - }) - if !testKnob.shouldLog() { - t.Errorf("expected shouldLog()=true") - } - assertLogs(t) - }) -} - -func assertLogs(t *testing.T) { - var buf bytes.Buffer - logf := func(format string, args ...any) { - fmt.Fprintf(&buf, format, args...) - } - - testKnob.Do(logf, "hello %s", "world") - const want = "hello world" - if got := buf.String(); got != want { - t.Errorf("got %q, want %q", got, want) - } -} - -func assertNoLogs(t *testing.T) { - var buf bytes.Buffer - logf := func(format string, args ...any) { - fmt.Fprintf(&buf, format, args...) - } - - testKnob.Do(logf, "hello %s", "world") - if got := buf.String(); got != "" { - t.Errorf("expected no logs, but got: %q", got) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package logknob + +import ( + "bytes" + "fmt" + "testing" + + "tailscale.com/envknob" + "tailscale.com/tailcfg" + "tailscale.com/types/netmap" +) + +var testKnob = NewLogKnob( + "TS_TEST_LOGKNOB", + "https://tailscale.com/cap/testing", +) + +// Static type assertion for our interface type. +var _ NetMap = &netmap.NetworkMap{} + +func TestLogKnob(t *testing.T) { + t.Run("Default", func(t *testing.T) { + if testKnob.shouldLog() { + t.Errorf("expected default shouldLog()=false") + } + assertNoLogs(t) + }) + t.Run("Manual", func(t *testing.T) { + t.Cleanup(func() { testKnob.Set(false) }) + + assertNoLogs(t) + testKnob.Set(true) + if !testKnob.shouldLog() { + t.Errorf("expected shouldLog()=true") + } + assertLogs(t) + }) + t.Run("Env", func(t *testing.T) { + t.Cleanup(func() { + envknob.Setenv("TS_TEST_LOGKNOB", "") + }) + + assertNoLogs(t) + if testKnob.shouldLog() { + t.Errorf("expected default shouldLog()=false") + } + + envknob.Setenv("TS_TEST_LOGKNOB", "true") + if !testKnob.shouldLog() { + t.Errorf("expected shouldLog()=true") + } + assertLogs(t) + }) + t.Run("NetMap", func(t *testing.T) { + t.Cleanup(func() { testKnob.cap.Store(false) }) + + assertNoLogs(t) + if testKnob.shouldLog() { + t.Errorf("expected default shouldLog()=false") + } + + testKnob.UpdateFromNetMap(&netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + Capabilities: []tailcfg.NodeCapability{ + "https://tailscale.com/cap/testing", + }, + }).View(), + }) + if !testKnob.shouldLog() { + t.Errorf("expected shouldLog()=true") + } + assertLogs(t) + }) +} + +func assertLogs(t *testing.T) { + var buf bytes.Buffer + logf := func(format string, args ...any) { + fmt.Fprintf(&buf, format, args...) + } + + testKnob.Do(logf, "hello %s", "world") + const want = "hello world" + if got := buf.String(); got != want { + t.Errorf("got %q, want %q", got, want) + } +} + +func assertNoLogs(t *testing.T) { + var buf bytes.Buffer + logf := func(format string, args ...any) { + fmt.Fprintf(&buf, format, args...) + } + + testKnob.Do(logf, "hello %s", "world") + if got := buf.String(); got != "" { + t.Errorf("expected no logs, but got: %q", got) + } +} diff --git a/gomod_test.go b/gomod_test.go index 52fdd4639..f984b5d6f 100644 --- a/gomod_test.go +++ b/gomod_test.go @@ -1,25 +1,25 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tailscaleroot - -import ( - "os" - "testing" - - "golang.org/x/mod/modfile" -) - -func TestGoMod(t *testing.T) { - goMod, err := os.ReadFile("go.mod") - if err != nil { - t.Fatal(err) - } - f, err := modfile.Parse("go.mod", goMod, nil) - if err != nil { - t.Fatal(err) - } - if len(f.Replace) > 0 { - t.Errorf("go.mod has %d replace directives; expect zero in this repo", len(f.Replace)) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tailscaleroot + +import ( + "os" + "testing" + + "golang.org/x/mod/modfile" +) + +func TestGoMod(t *testing.T) { + goMod, err := os.ReadFile("go.mod") + if err != nil { + t.Fatal(err) + } + f, err := modfile.Parse("go.mod", goMod, nil) + if err != nil { + t.Fatal(err) + } + if len(f.Replace) > 0 { + t.Errorf("go.mod has %d replace directives; expect zero in this repo", len(f.Replace)) + } +} diff --git a/hostinfo/hostinfo_darwin.go b/hostinfo/hostinfo_darwin.go index a61d95b32..0b1774e77 100644 --- a/hostinfo/hostinfo_darwin.go +++ b/hostinfo/hostinfo_darwin.go @@ -1,21 +1,21 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build darwin - -package hostinfo - -import ( - "os" - "path/filepath" -) - -func init() { - packageType = packageTypeDarwin -} - -func packageTypeDarwin() string { - // Using tailscaled or IPNExtension? - exe, _ := os.Executable() - return filepath.Base(exe) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build darwin + +package hostinfo + +import ( + "os" + "path/filepath" +) + +func init() { + packageType = packageTypeDarwin +} + +func packageTypeDarwin() string { + // Using tailscaled or IPNExtension? + exe, _ := os.Executable() + return filepath.Base(exe) +} diff --git a/hostinfo/hostinfo_freebsd.go b/hostinfo/hostinfo_freebsd.go index 15c7783aa..3661b1322 100644 --- a/hostinfo/hostinfo_freebsd.go +++ b/hostinfo/hostinfo_freebsd.go @@ -1,64 +1,64 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build freebsd - -package hostinfo - -import ( - "bytes" - "os" - "os/exec" - - "golang.org/x/sys/unix" - "tailscale.com/types/ptr" - "tailscale.com/version/distro" -) - -func init() { - osVersion = lazyOSVersion.Get - distroName = distroNameFreeBSD - distroVersion = distroVersionFreeBSD -} - -var ( - lazyVersionMeta = &lazyAtomicValue[versionMeta]{f: ptr.To(freebsdVersionMeta)} - lazyOSVersion = &lazyAtomicValue[string]{f: ptr.To(osVersionFreeBSD)} -) - -func distroNameFreeBSD() string { - return lazyVersionMeta.Get().DistroName -} - -func distroVersionFreeBSD() string { - return lazyVersionMeta.Get().DistroVersion -} - -type versionMeta struct { - DistroName string - DistroVersion string - DistroCodeName string -} - -func osVersionFreeBSD() string { - var un unix.Utsname - unix.Uname(&un) - return unix.ByteSliceToString(un.Release[:]) -} - -func freebsdVersionMeta() (meta versionMeta) { - d := distro.Get() - meta.DistroName = string(d) - switch d { - case distro.Pfsense: - b, _ := os.ReadFile("/etc/version") - meta.DistroVersion = string(bytes.TrimSpace(b)) - case distro.OPNsense: - b, _ := exec.Command("opnsense-version").Output() - meta.DistroVersion = string(bytes.TrimSpace(b)) - case distro.TrueNAS: - b, _ := os.ReadFile("/etc/version") - meta.DistroVersion = string(bytes.TrimSpace(b)) - } - return -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build freebsd + +package hostinfo + +import ( + "bytes" + "os" + "os/exec" + + "golang.org/x/sys/unix" + "tailscale.com/types/ptr" + "tailscale.com/version/distro" +) + +func init() { + osVersion = lazyOSVersion.Get + distroName = distroNameFreeBSD + distroVersion = distroVersionFreeBSD +} + +var ( + lazyVersionMeta = &lazyAtomicValue[versionMeta]{f: ptr.To(freebsdVersionMeta)} + lazyOSVersion = &lazyAtomicValue[string]{f: ptr.To(osVersionFreeBSD)} +) + +func distroNameFreeBSD() string { + return lazyVersionMeta.Get().DistroName +} + +func distroVersionFreeBSD() string { + return lazyVersionMeta.Get().DistroVersion +} + +type versionMeta struct { + DistroName string + DistroVersion string + DistroCodeName string +} + +func osVersionFreeBSD() string { + var un unix.Utsname + unix.Uname(&un) + return unix.ByteSliceToString(un.Release[:]) +} + +func freebsdVersionMeta() (meta versionMeta) { + d := distro.Get() + meta.DistroName = string(d) + switch d { + case distro.Pfsense: + b, _ := os.ReadFile("/etc/version") + meta.DistroVersion = string(bytes.TrimSpace(b)) + case distro.OPNsense: + b, _ := exec.Command("opnsense-version").Output() + meta.DistroVersion = string(bytes.TrimSpace(b)) + case distro.TrueNAS: + b, _ := os.ReadFile("/etc/version") + meta.DistroVersion = string(bytes.TrimSpace(b)) + } + return +} diff --git a/hostinfo/hostinfo_test.go b/hostinfo/hostinfo_test.go index 76282ebf5..9fe32e044 100644 --- a/hostinfo/hostinfo_test.go +++ b/hostinfo/hostinfo_test.go @@ -1,51 +1,51 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package hostinfo - -import ( - "encoding/json" - "strings" - "testing" -) - -func TestNew(t *testing.T) { - hi := New() - if hi == nil { - t.Fatal("no Hostinfo") - } - j, err := json.MarshalIndent(hi, " ", "") - if err != nil { - t.Fatal(err) - } - t.Logf("Got: %s", j) -} - -func TestOSVersion(t *testing.T) { - if osVersion == nil { - t.Skip("not available for OS") - } - t.Logf("Got: %#q", osVersion()) -} - -func TestEtcAptSourceFileIsDisabled(t *testing.T) { - tests := []struct { - name string - in string - want bool - }{ - {"empty", "", false}, - {"normal", "deb foo\n", false}, - {"normal-commented", "# deb foo\n", false}, - {"normal-disabled-by-ubuntu", "# deb foo # disabled on upgrade to dingus\n", true}, - {"normal-disabled-then-uncommented", "deb foo # disabled on upgrade to dingus\n", false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := etcAptSourceFileIsDisabled(strings.NewReader(tt.in)) - if got != tt.want { - t.Errorf("got %v; want %v", got, tt.want) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package hostinfo + +import ( + "encoding/json" + "strings" + "testing" +) + +func TestNew(t *testing.T) { + hi := New() + if hi == nil { + t.Fatal("no Hostinfo") + } + j, err := json.MarshalIndent(hi, " ", "") + if err != nil { + t.Fatal(err) + } + t.Logf("Got: %s", j) +} + +func TestOSVersion(t *testing.T) { + if osVersion == nil { + t.Skip("not available for OS") + } + t.Logf("Got: %#q", osVersion()) +} + +func TestEtcAptSourceFileIsDisabled(t *testing.T) { + tests := []struct { + name string + in string + want bool + }{ + {"empty", "", false}, + {"normal", "deb foo\n", false}, + {"normal-commented", "# deb foo\n", false}, + {"normal-disabled-by-ubuntu", "# deb foo # disabled on upgrade to dingus\n", true}, + {"normal-disabled-then-uncommented", "deb foo # disabled on upgrade to dingus\n", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := etcAptSourceFileIsDisabled(strings.NewReader(tt.in)) + if got != tt.want { + t.Errorf("got %v; want %v", got, tt.want) + } + }) + } +} diff --git a/hostinfo/hostinfo_uname.go b/hostinfo/hostinfo_uname.go index 10995c1c7..32b733a03 100644 --- a/hostinfo/hostinfo_uname.go +++ b/hostinfo/hostinfo_uname.go @@ -1,38 +1,38 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux || freebsd || openbsd || darwin - -package hostinfo - -import ( - "runtime" - - "golang.org/x/sys/unix" - "tailscale.com/types/ptr" -) - -func init() { - unameMachine = lazyUnameMachine.Get -} - -var lazyUnameMachine = &lazyAtomicValue[string]{f: ptr.To(unameMachineUnix)} - -func unameMachineUnix() string { - switch runtime.GOOS { - case "android": - // Don't call on Android for now. We're late in the 1.36 release cycle - // and don't want to test syscall filters on various Android versions to - // see what's permitted. Notably, the hostinfo_linux.go file has build - // tag !android, so maybe Uname is verboten. - return "" - case "ios": - // For similar reasons, don't call on iOS. There aren't many iOS devices - // and we know their CPU properties so calling this is only risk and no - // reward. - return "" - } - var un unix.Utsname - unix.Uname(&un) - return unix.ByteSliceToString(un.Machine[:]) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux || freebsd || openbsd || darwin + +package hostinfo + +import ( + "runtime" + + "golang.org/x/sys/unix" + "tailscale.com/types/ptr" +) + +func init() { + unameMachine = lazyUnameMachine.Get +} + +var lazyUnameMachine = &lazyAtomicValue[string]{f: ptr.To(unameMachineUnix)} + +func unameMachineUnix() string { + switch runtime.GOOS { + case "android": + // Don't call on Android for now. We're late in the 1.36 release cycle + // and don't want to test syscall filters on various Android versions to + // see what's permitted. Notably, the hostinfo_linux.go file has build + // tag !android, so maybe Uname is verboten. + return "" + case "ios": + // For similar reasons, don't call on iOS. There aren't many iOS devices + // and we know their CPU properties so calling this is only risk and no + // reward. + return "" + } + var un unix.Utsname + unix.Uname(&un) + return unix.ByteSliceToString(un.Machine[:]) +} diff --git a/hostinfo/wol.go b/hostinfo/wol.go index b6fc81a8b..3a30af2fe 100644 --- a/hostinfo/wol.go +++ b/hostinfo/wol.go @@ -1,106 +1,106 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package hostinfo - -import ( - "log" - "net" - "runtime" - "strings" - "unicode" - - "tailscale.com/envknob" -) - -// TODO(bradfitz): this is all too simplistic and static. It needs to run -// continuously in response to netmon events (USB ethernet adapaters might get -// plugged in) and look for the media type/status/etc. Right now on macOS it -// still detects a half dozen "up" en0, en1, en2, en3 etc interfaces that don't -// have any media. We should only report the one that's actually connected. -// But it works for now (2023-10-05) for fleshing out the rest. - -var wakeMAC = envknob.RegisterString("TS_WAKE_MAC") // mac address, "false" or "auto". for https://github.com/tailscale/tailscale/issues/306 - -// getWoLMACs returns up to 10 MAC address of the local machine to send -// wake-on-LAN packets to in order to wake it up. The returned MACs are in -// lowercase hex colon-separated form ("xx:xx:xx:xx:xx:xx"). -// -// If TS_WAKE_MAC=auto, it tries to automatically find the MACs based on the OS -// type and interface properties. (TODO(bradfitz): incomplete) If TS_WAKE_MAC is -// set to a MAC address, that sole MAC address is returned. -func getWoLMACs() (macs []string) { - switch runtime.GOOS { - case "ios", "android": - return nil - } - if s := wakeMAC(); s != "" { - switch s { - case "auto": - ifs, _ := net.Interfaces() - for _, iface := range ifs { - if iface.Flags&net.FlagLoopback != 0 { - continue - } - if iface.Flags&net.FlagBroadcast == 0 || - iface.Flags&net.FlagRunning == 0 || - iface.Flags&net.FlagUp == 0 { - continue - } - if keepMAC(iface.Name, iface.HardwareAddr) { - macs = append(macs, iface.HardwareAddr.String()) - } - if len(macs) == 10 { - break - } - } - return macs - case "false", "off": // fast path before ParseMAC error - return nil - } - mac, err := net.ParseMAC(s) - if err != nil { - log.Printf("invalid MAC %q", s) - return nil - } - return []string{mac.String()} - } - return nil -} - -var ignoreWakeOUI = map[[3]byte]bool{ - {0x00, 0x15, 0x5d}: true, // Hyper-V - {0x00, 0x50, 0x56}: true, // VMware - {0x00, 0x1c, 0x14}: true, // VMware - {0x00, 0x05, 0x69}: true, // VMware - {0x00, 0x0c, 0x29}: true, // VMware - {0x00, 0x1c, 0x42}: true, // Parallels - {0x08, 0x00, 0x27}: true, // VirtualBox - {0x00, 0x21, 0xf6}: true, // VirtualBox - {0x00, 0x14, 0x4f}: true, // VirtualBox - {0x00, 0x0f, 0x4b}: true, // VirtualBox - {0x52, 0x54, 0x00}: true, // VirtualBox/Vagrant -} - -func keepMAC(ifName string, mac []byte) bool { - if len(mac) != 6 { - return false - } - base := strings.TrimRightFunc(ifName, unicode.IsNumber) - switch runtime.GOOS { - case "darwin": - switch base { - case "llw", "awdl", "utun", "bridge", "lo", "gif", "stf", "anpi", "ap": - return false - } - } - if mac[0] == 0x02 && mac[1] == 0x42 { - // Docker container. - return false - } - oui := [3]byte{mac[0], mac[1], mac[2]} - if ignoreWakeOUI[oui] { - return false - } - return true -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package hostinfo + +import ( + "log" + "net" + "runtime" + "strings" + "unicode" + + "tailscale.com/envknob" +) + +// TODO(bradfitz): this is all too simplistic and static. It needs to run +// continuously in response to netmon events (USB ethernet adapaters might get +// plugged in) and look for the media type/status/etc. Right now on macOS it +// still detects a half dozen "up" en0, en1, en2, en3 etc interfaces that don't +// have any media. We should only report the one that's actually connected. +// But it works for now (2023-10-05) for fleshing out the rest. + +var wakeMAC = envknob.RegisterString("TS_WAKE_MAC") // mac address, "false" or "auto". for https://github.com/tailscale/tailscale/issues/306 + +// getWoLMACs returns up to 10 MAC address of the local machine to send +// wake-on-LAN packets to in order to wake it up. The returned MACs are in +// lowercase hex colon-separated form ("xx:xx:xx:xx:xx:xx"). +// +// If TS_WAKE_MAC=auto, it tries to automatically find the MACs based on the OS +// type and interface properties. (TODO(bradfitz): incomplete) If TS_WAKE_MAC is +// set to a MAC address, that sole MAC address is returned. +func getWoLMACs() (macs []string) { + switch runtime.GOOS { + case "ios", "android": + return nil + } + if s := wakeMAC(); s != "" { + switch s { + case "auto": + ifs, _ := net.Interfaces() + for _, iface := range ifs { + if iface.Flags&net.FlagLoopback != 0 { + continue + } + if iface.Flags&net.FlagBroadcast == 0 || + iface.Flags&net.FlagRunning == 0 || + iface.Flags&net.FlagUp == 0 { + continue + } + if keepMAC(iface.Name, iface.HardwareAddr) { + macs = append(macs, iface.HardwareAddr.String()) + } + if len(macs) == 10 { + break + } + } + return macs + case "false", "off": // fast path before ParseMAC error + return nil + } + mac, err := net.ParseMAC(s) + if err != nil { + log.Printf("invalid MAC %q", s) + return nil + } + return []string{mac.String()} + } + return nil +} + +var ignoreWakeOUI = map[[3]byte]bool{ + {0x00, 0x15, 0x5d}: true, // Hyper-V + {0x00, 0x50, 0x56}: true, // VMware + {0x00, 0x1c, 0x14}: true, // VMware + {0x00, 0x05, 0x69}: true, // VMware + {0x00, 0x0c, 0x29}: true, // VMware + {0x00, 0x1c, 0x42}: true, // Parallels + {0x08, 0x00, 0x27}: true, // VirtualBox + {0x00, 0x21, 0xf6}: true, // VirtualBox + {0x00, 0x14, 0x4f}: true, // VirtualBox + {0x00, 0x0f, 0x4b}: true, // VirtualBox + {0x52, 0x54, 0x00}: true, // VirtualBox/Vagrant +} + +func keepMAC(ifName string, mac []byte) bool { + if len(mac) != 6 { + return false + } + base := strings.TrimRightFunc(ifName, unicode.IsNumber) + switch runtime.GOOS { + case "darwin": + switch base { + case "llw", "awdl", "utun", "bridge", "lo", "gif", "stf", "anpi", "ap": + return false + } + } + if mac[0] == 0x02 && mac[1] == 0x42 { + // Docker container. + return false + } + oui := [3]byte{mac[0], mac[1], mac[2]} + if ignoreWakeOUI[oui] { + return false + } + return true +} diff --git a/ipn/ipnlocal/breaktcp_darwin.go b/ipn/ipnlocal/breaktcp_darwin.go index 289e760e1..13566198c 100644 --- a/ipn/ipnlocal/breaktcp_darwin.go +++ b/ipn/ipnlocal/breaktcp_darwin.go @@ -1,30 +1,30 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package ipnlocal - -import ( - "log" - - "golang.org/x/sys/unix" -) - -func init() { - breakTCPConns = breakTCPConnsDarwin -} - -func breakTCPConnsDarwin() error { - var matched int - for fd := 0; fd < 1000; fd++ { - _, err := unix.GetsockoptTCPConnectionInfo(fd, unix.IPPROTO_TCP, unix.TCP_CONNECTION_INFO) - if err == nil { - matched++ - err = unix.Close(fd) - log.Printf("debug: closed TCP fd %v: %v", fd, err) - } - } - if matched == 0 { - log.Printf("debug: no TCP connections found") - } - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnlocal + +import ( + "log" + + "golang.org/x/sys/unix" +) + +func init() { + breakTCPConns = breakTCPConnsDarwin +} + +func breakTCPConnsDarwin() error { + var matched int + for fd := 0; fd < 1000; fd++ { + _, err := unix.GetsockoptTCPConnectionInfo(fd, unix.IPPROTO_TCP, unix.TCP_CONNECTION_INFO) + if err == nil { + matched++ + err = unix.Close(fd) + log.Printf("debug: closed TCP fd %v: %v", fd, err) + } + } + if matched == 0 { + log.Printf("debug: no TCP connections found") + } + return nil +} diff --git a/ipn/ipnlocal/breaktcp_linux.go b/ipn/ipnlocal/breaktcp_linux.go index d078103cf..b82f65212 100644 --- a/ipn/ipnlocal/breaktcp_linux.go +++ b/ipn/ipnlocal/breaktcp_linux.go @@ -1,30 +1,30 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package ipnlocal - -import ( - "log" - - "golang.org/x/sys/unix" -) - -func init() { - breakTCPConns = breakTCPConnsLinux -} - -func breakTCPConnsLinux() error { - var matched int - for fd := 0; fd < 1000; fd++ { - _, err := unix.GetsockoptTCPInfo(fd, unix.IPPROTO_TCP, unix.TCP_INFO) - if err == nil { - matched++ - err = unix.Close(fd) - log.Printf("debug: closed TCP fd %v: %v", fd, err) - } - } - if matched == 0 { - log.Printf("debug: no TCP connections found") - } - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnlocal + +import ( + "log" + + "golang.org/x/sys/unix" +) + +func init() { + breakTCPConns = breakTCPConnsLinux +} + +func breakTCPConnsLinux() error { + var matched int + for fd := 0; fd < 1000; fd++ { + _, err := unix.GetsockoptTCPInfo(fd, unix.IPPROTO_TCP, unix.TCP_INFO) + if err == nil { + matched++ + err = unix.Close(fd) + log.Printf("debug: closed TCP fd %v: %v", fd, err) + } + } + if matched == 0 { + log.Printf("debug: no TCP connections found") + } + return nil +} diff --git a/ipn/ipnlocal/expiry_test.go b/ipn/ipnlocal/expiry_test.go index efc18133f..af1aa337b 100644 --- a/ipn/ipnlocal/expiry_test.go +++ b/ipn/ipnlocal/expiry_test.go @@ -1,301 +1,301 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package ipnlocal - -import ( - "fmt" - "reflect" - "strings" - "testing" - "time" - - "tailscale.com/tailcfg" - "tailscale.com/tstest" - "tailscale.com/types/key" - "tailscale.com/types/netmap" -) - -func TestFlagExpiredPeers(t *testing.T) { - n := func(id tailcfg.NodeID, name string, expiry time.Time, mod ...func(*tailcfg.Node)) *tailcfg.Node { - n := &tailcfg.Node{ID: id, Name: name, KeyExpiry: expiry} - for _, f := range mod { - f(n) - } - return n - } - - now := time.Unix(1673373129, 0) - - timeInPast := now.Add(-1 * time.Hour) - timeInFuture := now.Add(1 * time.Hour) - - timeBeforeEpoch := flagExpiredPeersEpoch.Add(-1 * time.Second) - if now.Before(timeBeforeEpoch) { - panic("current time in test cannot be before epoch") - } - - var expiredKey key.NodePublic - if err := expiredKey.UnmarshalText([]byte("nodekey:6da774d5d7740000000000000000000000000000000000000000000000000000")); err != nil { - panic(err) - } - - tests := []struct { - name string - controlTime *time.Time - netmap *netmap.NetworkMap - want []tailcfg.NodeView - }{ - { - name: "no_expiry", - controlTime: &now, - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInFuture), - n(2, "bar", timeInFuture), - }), - }, - want: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInFuture), - n(2, "bar", timeInFuture), - }), - }, - { - name: "expiry", - controlTime: &now, - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInFuture), - n(2, "bar", timeInPast), - }), - }, - want: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInFuture), - n(2, "bar", timeInPast, func(n *tailcfg.Node) { - n.Expired = true - n.Key = expiredKey - }), - }), - }, - { - name: "bad_ControlTime", - // controlTime here is intentionally before our hardcoded epoch - controlTime: &timeBeforeEpoch, - - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInFuture), - n(2, "bar", timeBeforeEpoch.Add(-1*time.Hour)), // before ControlTime - }), - }, - want: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInFuture), - n(2, "bar", timeBeforeEpoch.Add(-1*time.Hour)), // should have expired, but ControlTime is before epoch - }), - }, - { - name: "tagged_node", - controlTime: &now, - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInFuture), - n(2, "bar", time.Time{}), // tagged node; zero expiry - }), - }, - want: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInFuture), - n(2, "bar", time.Time{}), // not expired - }), - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - em := newExpiryManager(t.Logf) - em.clock = tstest.NewClock(tstest.ClockOpts{Start: now}) - if tt.controlTime != nil { - em.onControlTime(*tt.controlTime) - } - em.flagExpiredPeers(tt.netmap, now) - if !reflect.DeepEqual(tt.netmap.Peers, tt.want) { - t.Errorf("wrong results\n got: %s\nwant: %s", formatNodes(tt.netmap.Peers), formatNodes(tt.want)) - } - }) - } -} - -func TestNextPeerExpiry(t *testing.T) { - n := func(id tailcfg.NodeID, name string, expiry time.Time, mod ...func(*tailcfg.Node)) *tailcfg.Node { - n := &tailcfg.Node{ID: id, Name: name, KeyExpiry: expiry} - for _, f := range mod { - f(n) - } - return n - } - - now := time.Unix(1675725516, 0) - - noExpiry := time.Time{} - timeInPast := now.Add(-1 * time.Hour) - timeInFuture := now.Add(1 * time.Hour) - timeInMoreFuture := now.Add(2 * time.Hour) - - tests := []struct { - name string - netmap *netmap.NetworkMap - want time.Time - }{ - { - name: "no_expiry", - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", noExpiry), - n(2, "bar", noExpiry), - }), - SelfNode: n(3, "self", noExpiry).View(), - }, - want: noExpiry, - }, - { - name: "future_expiry_from_peer", - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", noExpiry), - n(2, "bar", timeInFuture), - }), - SelfNode: n(3, "self", noExpiry).View(), - }, - want: timeInFuture, - }, - { - name: "future_expiry_from_self", - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", noExpiry), - n(2, "bar", noExpiry), - }), - SelfNode: n(3, "self", timeInFuture).View(), - }, - want: timeInFuture, - }, - { - name: "future_expiry_from_multiple_peers", - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInFuture), - n(2, "bar", timeInMoreFuture), - }), - SelfNode: n(3, "self", noExpiry).View(), - }, - want: timeInFuture, - }, - { - name: "future_expiry_from_peer_and_self", - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInMoreFuture), - }), - SelfNode: n(2, "self", timeInFuture).View(), - }, - want: timeInFuture, - }, - { - name: "only_self", - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{}), - SelfNode: n(1, "self", timeInFuture).View(), - }, - want: timeInFuture, - }, - { - name: "peer_already_expired", - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInPast), - }), - SelfNode: n(2, "self", timeInFuture).View(), - }, - want: timeInFuture, - }, - { - name: "self_already_expired", - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInFuture), - }), - SelfNode: n(2, "self", timeInPast).View(), - }, - want: timeInFuture, - }, - { - name: "all_nodes_already_expired", - netmap: &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInPast), - }), - SelfNode: n(2, "self", timeInPast).View(), - }, - want: noExpiry, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - em := newExpiryManager(t.Logf) - em.clock = tstest.NewClock(tstest.ClockOpts{Start: now}) - got := em.nextPeerExpiry(tt.netmap, now) - if !got.Equal(tt.want) { - t.Errorf("got %q, want %q", got.Format(time.RFC3339), tt.want.Format(time.RFC3339)) - } else if !got.IsZero() && got.Before(now) { - t.Errorf("unexpectedly got expiry %q before now %q", got.Format(time.RFC3339), now.Format(time.RFC3339)) - } - }) - } - - t.Run("ClockSkew", func(t *testing.T) { - t.Logf("local time: %q", now.Format(time.RFC3339)) - em := newExpiryManager(t.Logf) - em.clock = tstest.NewClock(tstest.ClockOpts{Start: now}) - - // The local clock is "running fast"; our clock skew is -2h - em.clockDelta.Store(-2 * time.Hour) - t.Logf("'real' time: %q", now.Add(-2*time.Hour).Format(time.RFC3339)) - - // If we don't adjust for the local time, this would return a - // time in the past. - nm := &netmap.NetworkMap{ - Peers: nodeViews([]*tailcfg.Node{ - n(1, "foo", timeInPast), - }), - } - got := em.nextPeerExpiry(nm, now) - want := now.Add(30 * time.Second) - if !got.Equal(want) { - t.Errorf("got %q, want %q", got.Format(time.RFC3339), want.Format(time.RFC3339)) - } - }) -} - -func formatNodes(nodes []tailcfg.NodeView) string { - var sb strings.Builder - for i, n := range nodes { - if i > 0 { - sb.WriteString(", ") - } - fmt.Fprintf(&sb, "(%d, %q", n.ID(), n.Name()) - - if n.Online() != nil { - fmt.Fprintf(&sb, ", online=%v", *n.Online()) - } - if n.LastSeen() != nil { - fmt.Fprintf(&sb, ", lastSeen=%v", n.LastSeen().Unix()) - } - if n.Key() != (key.NodePublic{}) { - fmt.Fprintf(&sb, ", key=%v", n.Key().String()) - } - if n.Expired() { - fmt.Fprintf(&sb, ", expired=true") - } - sb.WriteString(")") - } - return sb.String() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnlocal + +import ( + "fmt" + "reflect" + "strings" + "testing" + "time" + + "tailscale.com/tailcfg" + "tailscale.com/tstest" + "tailscale.com/types/key" + "tailscale.com/types/netmap" +) + +func TestFlagExpiredPeers(t *testing.T) { + n := func(id tailcfg.NodeID, name string, expiry time.Time, mod ...func(*tailcfg.Node)) *tailcfg.Node { + n := &tailcfg.Node{ID: id, Name: name, KeyExpiry: expiry} + for _, f := range mod { + f(n) + } + return n + } + + now := time.Unix(1673373129, 0) + + timeInPast := now.Add(-1 * time.Hour) + timeInFuture := now.Add(1 * time.Hour) + + timeBeforeEpoch := flagExpiredPeersEpoch.Add(-1 * time.Second) + if now.Before(timeBeforeEpoch) { + panic("current time in test cannot be before epoch") + } + + var expiredKey key.NodePublic + if err := expiredKey.UnmarshalText([]byte("nodekey:6da774d5d7740000000000000000000000000000000000000000000000000000")); err != nil { + panic(err) + } + + tests := []struct { + name string + controlTime *time.Time + netmap *netmap.NetworkMap + want []tailcfg.NodeView + }{ + { + name: "no_expiry", + controlTime: &now, + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInFuture), + n(2, "bar", timeInFuture), + }), + }, + want: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInFuture), + n(2, "bar", timeInFuture), + }), + }, + { + name: "expiry", + controlTime: &now, + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInFuture), + n(2, "bar", timeInPast), + }), + }, + want: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInFuture), + n(2, "bar", timeInPast, func(n *tailcfg.Node) { + n.Expired = true + n.Key = expiredKey + }), + }), + }, + { + name: "bad_ControlTime", + // controlTime here is intentionally before our hardcoded epoch + controlTime: &timeBeforeEpoch, + + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInFuture), + n(2, "bar", timeBeforeEpoch.Add(-1*time.Hour)), // before ControlTime + }), + }, + want: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInFuture), + n(2, "bar", timeBeforeEpoch.Add(-1*time.Hour)), // should have expired, but ControlTime is before epoch + }), + }, + { + name: "tagged_node", + controlTime: &now, + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInFuture), + n(2, "bar", time.Time{}), // tagged node; zero expiry + }), + }, + want: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInFuture), + n(2, "bar", time.Time{}), // not expired + }), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + em := newExpiryManager(t.Logf) + em.clock = tstest.NewClock(tstest.ClockOpts{Start: now}) + if tt.controlTime != nil { + em.onControlTime(*tt.controlTime) + } + em.flagExpiredPeers(tt.netmap, now) + if !reflect.DeepEqual(tt.netmap.Peers, tt.want) { + t.Errorf("wrong results\n got: %s\nwant: %s", formatNodes(tt.netmap.Peers), formatNodes(tt.want)) + } + }) + } +} + +func TestNextPeerExpiry(t *testing.T) { + n := func(id tailcfg.NodeID, name string, expiry time.Time, mod ...func(*tailcfg.Node)) *tailcfg.Node { + n := &tailcfg.Node{ID: id, Name: name, KeyExpiry: expiry} + for _, f := range mod { + f(n) + } + return n + } + + now := time.Unix(1675725516, 0) + + noExpiry := time.Time{} + timeInPast := now.Add(-1 * time.Hour) + timeInFuture := now.Add(1 * time.Hour) + timeInMoreFuture := now.Add(2 * time.Hour) + + tests := []struct { + name string + netmap *netmap.NetworkMap + want time.Time + }{ + { + name: "no_expiry", + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", noExpiry), + n(2, "bar", noExpiry), + }), + SelfNode: n(3, "self", noExpiry).View(), + }, + want: noExpiry, + }, + { + name: "future_expiry_from_peer", + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", noExpiry), + n(2, "bar", timeInFuture), + }), + SelfNode: n(3, "self", noExpiry).View(), + }, + want: timeInFuture, + }, + { + name: "future_expiry_from_self", + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", noExpiry), + n(2, "bar", noExpiry), + }), + SelfNode: n(3, "self", timeInFuture).View(), + }, + want: timeInFuture, + }, + { + name: "future_expiry_from_multiple_peers", + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInFuture), + n(2, "bar", timeInMoreFuture), + }), + SelfNode: n(3, "self", noExpiry).View(), + }, + want: timeInFuture, + }, + { + name: "future_expiry_from_peer_and_self", + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInMoreFuture), + }), + SelfNode: n(2, "self", timeInFuture).View(), + }, + want: timeInFuture, + }, + { + name: "only_self", + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{}), + SelfNode: n(1, "self", timeInFuture).View(), + }, + want: timeInFuture, + }, + { + name: "peer_already_expired", + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInPast), + }), + SelfNode: n(2, "self", timeInFuture).View(), + }, + want: timeInFuture, + }, + { + name: "self_already_expired", + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInFuture), + }), + SelfNode: n(2, "self", timeInPast).View(), + }, + want: timeInFuture, + }, + { + name: "all_nodes_already_expired", + netmap: &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInPast), + }), + SelfNode: n(2, "self", timeInPast).View(), + }, + want: noExpiry, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + em := newExpiryManager(t.Logf) + em.clock = tstest.NewClock(tstest.ClockOpts{Start: now}) + got := em.nextPeerExpiry(tt.netmap, now) + if !got.Equal(tt.want) { + t.Errorf("got %q, want %q", got.Format(time.RFC3339), tt.want.Format(time.RFC3339)) + } else if !got.IsZero() && got.Before(now) { + t.Errorf("unexpectedly got expiry %q before now %q", got.Format(time.RFC3339), now.Format(time.RFC3339)) + } + }) + } + + t.Run("ClockSkew", func(t *testing.T) { + t.Logf("local time: %q", now.Format(time.RFC3339)) + em := newExpiryManager(t.Logf) + em.clock = tstest.NewClock(tstest.ClockOpts{Start: now}) + + // The local clock is "running fast"; our clock skew is -2h + em.clockDelta.Store(-2 * time.Hour) + t.Logf("'real' time: %q", now.Add(-2*time.Hour).Format(time.RFC3339)) + + // If we don't adjust for the local time, this would return a + // time in the past. + nm := &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + n(1, "foo", timeInPast), + }), + } + got := em.nextPeerExpiry(nm, now) + want := now.Add(30 * time.Second) + if !got.Equal(want) { + t.Errorf("got %q, want %q", got.Format(time.RFC3339), want.Format(time.RFC3339)) + } + }) +} + +func formatNodes(nodes []tailcfg.NodeView) string { + var sb strings.Builder + for i, n := range nodes { + if i > 0 { + sb.WriteString(", ") + } + fmt.Fprintf(&sb, "(%d, %q", n.ID(), n.Name()) + + if n.Online() != nil { + fmt.Fprintf(&sb, ", online=%v", *n.Online()) + } + if n.LastSeen() != nil { + fmt.Fprintf(&sb, ", lastSeen=%v", n.LastSeen().Unix()) + } + if n.Key() != (key.NodePublic{}) { + fmt.Fprintf(&sb, ", key=%v", n.Key().String()) + } + if n.Expired() { + fmt.Fprintf(&sb, ", expired=true") + } + sb.WriteString(")") + } + return sb.String() +} diff --git a/ipn/ipnlocal/peerapi_h2c.go b/ipn/ipnlocal/peerapi_h2c.go index e6335fe2b..fbfa86398 100644 --- a/ipn/ipnlocal/peerapi_h2c.go +++ b/ipn/ipnlocal/peerapi_h2c.go @@ -1,20 +1,20 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !ios && !android && !js - -package ipnlocal - -import ( - "net/http" - - "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" -) - -func init() { - addH2C = func(s *http.Server) { - h2s := &http2.Server{} - s.Handler = h2c.NewHandler(s.Handler, h2s) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ios && !android && !js + +package ipnlocal + +import ( + "net/http" + + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" +) + +func init() { + addH2C = func(s *http.Server) { + h2s := &http2.Server{} + s.Handler = h2c.NewHandler(s.Handler, h2s) + } +} diff --git a/ipn/ipnlocal/testdata/example.com-key.pem b/ipn/ipnlocal/testdata/example.com-key.pem index 9020553f1..06902f4c9 100644 --- a/ipn/ipnlocal/testdata/example.com-key.pem +++ b/ipn/ipnlocal/testdata/example.com-key.pem @@ -1,28 +1,28 @@ ------BEGIN PRIVATE KEY----- -MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCejQaJrntrJSgE -QtScyTU6TXOU+v1FdFjrsyHFK5mjV1C5pVQxnLn93GRshtIrGOLLrd3Wv2TVYZOX -xH7f1ZLFbneDURCXbS+7nmsg+TLHRSRKfODbE3oYZj7NSJ163CCvwSJKTdmLpXbn -ui9F04tyk0zxO4Wre4ukwf6xtse8G5zl2RJrueiVAiouTG/pJdIS08dGQa0GM1n9 -Aesa+TerlZcpRZR6X402yQqa8q/QqbIuzrlfDmgOb8sm6T8+JMtj3hEvnYdpMVOg -w/XiTlX0v/YrB9sVQ9XnqGsqwTL0OMG0choMNKipwLi2n+XPSCIiRhi666zNNivE -K1qaPS5RAgMBAAECggEAV9dAGQWPISR70CiKjLa5A60nbRHFQjackTE0c32daC6W -7dOYGsh/DxOMm8fyJqhp9nhEYJa3MbUWxU27ER3NbA6wrhM6gvqeKG8zYRhPNrGq -0o3vMdDPozb6cldZ0Fimz1jMO6h373NjtiyjxibWqkrLpRbaDtCq5EQKbMEcVa2D -Xt5hxCOaCA3OZ/mAcGUNFmDNgNsGP/r6eXdI5pbqnUNMPkv/JsHl8h2HuyKUm4hf -TRnXPAak6DkUod9QXYFKVBVPa5pjiO09e0aiMUvJ8vYd/6bNIsAKWLPa1PYuUE2l -kg8Nik+P/XLzffKsLxiFKY0nCqrorM9K5q7baofGdQKBgQDPujjebFg6OKw6MS3S -PESopvL//C/XgtgifcSSZCWzIZRVBVTbbJCGRtqFzF0XO4YRX3EOAyD/L7wYUPzO -+W3AU2W3/DVJYdcm2CASABbHNy0kk52LI0HHAssbFDgyB9XuuWP+vVZk7B5OmCAD -Bppuj6Mnu03i282nKNJzvRiVnwKBgQDDZUXv22K8y7GkKw/ZW/wQP2zBNtFc15he -1EOyUGHlXuQixnDSaqonkwec6IOlo7Sx/vwO/7+v4Jzc24Wq3DFAmMu/EYJgvI+m -m3kpB4H7Xus4JqnhxqN7GB7zOdguCWZF1HLemZNZlVrUjG5mQ9cizzvvYptnQDLq -FEJ1hddWDwKBgB+vy276Xfb7oCH8UH4KXXrQhK7RvEaGmgug3bRq/Gk3zRWvC4Ox -KtagxkK0qtqZZNkPkwJNLeJfWLTo3beAyuIUlqabHVHFT/mH7FRymQbofsVekyCf -TzBZV7wYuH3BPjv9IajBHwWkEvdwMyni/vmwhXXRF49schF2o6uuA6sHAoGBAL1J -Xnb+EKjUq0JedPwcIBOdXb3PXQKT2QgEmZAkTrHlOxx1INa2fh/YT4ext9a+wE2u -tn/RQeEfttY90z+yEASEAN0YGTWddYvxEW6t1z2stjGvQuN1ium0dEcrwkDW2jzL -knwSSqx+A3/kiw6GqeMO3wEIhYOArdIVzkwLXJABAoGAOXLGhz5u5FWjF3zAeYme -uHTU/3Z3jeI80PvShGrgAakPOBt3cIFpUaiOEslcqqgDUSGE3EnmkRqaEch+UapF -ty6Zz7cKjXhQSWOjew1uUW2ANNEpsnYbmZOOnfvosd7jfHSVbL6KIhWmIdC6h0NP -c/bJnTXEEVsWjLZTwYaq0Us= +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCejQaJrntrJSgE +QtScyTU6TXOU+v1FdFjrsyHFK5mjV1C5pVQxnLn93GRshtIrGOLLrd3Wv2TVYZOX +xH7f1ZLFbneDURCXbS+7nmsg+TLHRSRKfODbE3oYZj7NSJ163CCvwSJKTdmLpXbn +ui9F04tyk0zxO4Wre4ukwf6xtse8G5zl2RJrueiVAiouTG/pJdIS08dGQa0GM1n9 +Aesa+TerlZcpRZR6X402yQqa8q/QqbIuzrlfDmgOb8sm6T8+JMtj3hEvnYdpMVOg +w/XiTlX0v/YrB9sVQ9XnqGsqwTL0OMG0choMNKipwLi2n+XPSCIiRhi666zNNivE +K1qaPS5RAgMBAAECggEAV9dAGQWPISR70CiKjLa5A60nbRHFQjackTE0c32daC6W +7dOYGsh/DxOMm8fyJqhp9nhEYJa3MbUWxU27ER3NbA6wrhM6gvqeKG8zYRhPNrGq +0o3vMdDPozb6cldZ0Fimz1jMO6h373NjtiyjxibWqkrLpRbaDtCq5EQKbMEcVa2D +Xt5hxCOaCA3OZ/mAcGUNFmDNgNsGP/r6eXdI5pbqnUNMPkv/JsHl8h2HuyKUm4hf +TRnXPAak6DkUod9QXYFKVBVPa5pjiO09e0aiMUvJ8vYd/6bNIsAKWLPa1PYuUE2l +kg8Nik+P/XLzffKsLxiFKY0nCqrorM9K5q7baofGdQKBgQDPujjebFg6OKw6MS3S +PESopvL//C/XgtgifcSSZCWzIZRVBVTbbJCGRtqFzF0XO4YRX3EOAyD/L7wYUPzO ++W3AU2W3/DVJYdcm2CASABbHNy0kk52LI0HHAssbFDgyB9XuuWP+vVZk7B5OmCAD +Bppuj6Mnu03i282nKNJzvRiVnwKBgQDDZUXv22K8y7GkKw/ZW/wQP2zBNtFc15he +1EOyUGHlXuQixnDSaqonkwec6IOlo7Sx/vwO/7+v4Jzc24Wq3DFAmMu/EYJgvI+m +m3kpB4H7Xus4JqnhxqN7GB7zOdguCWZF1HLemZNZlVrUjG5mQ9cizzvvYptnQDLq +FEJ1hddWDwKBgB+vy276Xfb7oCH8UH4KXXrQhK7RvEaGmgug3bRq/Gk3zRWvC4Ox +KtagxkK0qtqZZNkPkwJNLeJfWLTo3beAyuIUlqabHVHFT/mH7FRymQbofsVekyCf +TzBZV7wYuH3BPjv9IajBHwWkEvdwMyni/vmwhXXRF49schF2o6uuA6sHAoGBAL1J +Xnb+EKjUq0JedPwcIBOdXb3PXQKT2QgEmZAkTrHlOxx1INa2fh/YT4ext9a+wE2u +tn/RQeEfttY90z+yEASEAN0YGTWddYvxEW6t1z2stjGvQuN1ium0dEcrwkDW2jzL +knwSSqx+A3/kiw6GqeMO3wEIhYOArdIVzkwLXJABAoGAOXLGhz5u5FWjF3zAeYme +uHTU/3Z3jeI80PvShGrgAakPOBt3cIFpUaiOEslcqqgDUSGE3EnmkRqaEch+UapF +ty6Zz7cKjXhQSWOjew1uUW2ANNEpsnYbmZOOnfvosd7jfHSVbL6KIhWmIdC6h0NP +c/bJnTXEEVsWjLZTwYaq0Us= -----END PRIVATE KEY----- \ No newline at end of file diff --git a/ipn/ipnlocal/testdata/example.com.pem b/ipn/ipnlocal/testdata/example.com.pem index 65e7110a8..588850813 100644 --- a/ipn/ipnlocal/testdata/example.com.pem +++ b/ipn/ipnlocal/testdata/example.com.pem @@ -1,26 +1,26 @@ ------BEGIN CERTIFICATE----- -MIIEcDCCAtigAwIBAgIRAPmUKRkyFAkVVxFblB/233cwDQYJKoZIhvcNAQELBQAw -gZ8xHjAcBgNVBAoTFW1rY2VydCBkZXZlbG9wbWVudCBDQTE6MDgGA1UECwwxZnJv -bWJlcmdlckBzdGFyZHVzdC5sb2NhbCAoTWljaGFlbCBKLiBGcm9tYmVyZ2VyKTFB -MD8GA1UEAww4bWtjZXJ0IGZyb21iZXJnZXJAc3RhcmR1c3QubG9jYWwgKE1pY2hh -ZWwgSi4gRnJvbWJlcmdlcikwHhcNMjMwMjA3MjAzNDE4WhcNMjUwNTA3MTkzNDE4 -WjBlMScwJQYDVQQKEx5ta2NlcnQgZGV2ZWxvcG1lbnQgY2VydGlmaWNhdGUxOjA4 -BgNVBAsMMWZyb21iZXJnZXJAc3RhcmR1c3QubG9jYWwgKE1pY2hhZWwgSi4gRnJv -bWJlcmdlcikwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCejQaJrntr -JSgEQtScyTU6TXOU+v1FdFjrsyHFK5mjV1C5pVQxnLn93GRshtIrGOLLrd3Wv2TV -YZOXxH7f1ZLFbneDURCXbS+7nmsg+TLHRSRKfODbE3oYZj7NSJ163CCvwSJKTdmL -pXbnui9F04tyk0zxO4Wre4ukwf6xtse8G5zl2RJrueiVAiouTG/pJdIS08dGQa0G -M1n9Aesa+TerlZcpRZR6X402yQqa8q/QqbIuzrlfDmgOb8sm6T8+JMtj3hEvnYdp -MVOgw/XiTlX0v/YrB9sVQ9XnqGsqwTL0OMG0choMNKipwLi2n+XPSCIiRhi666zN -NivEK1qaPS5RAgMBAAGjYDBeMA4GA1UdDwEB/wQEAwIFoDATBgNVHSUEDDAKBggr -BgEFBQcDATAfBgNVHSMEGDAWgBTXyq2jQVrnqQKL8fB9C4L0QJftwDAWBgNVHREE -DzANggtleGFtcGxlLmNvbTANBgkqhkiG9w0BAQsFAAOCAYEAQWzpOaBkRR4M+WqB -CsT4ARyM6WpZ+jpeSblCzPdlDRW+50G1HV7K930zayq4DwncPY/SqSn0Q31WuzZv -bTWHkWa+MLPGYANHsusOmMR8Eh16G4+5+GGf8psWa0npAYO35cuNkyyCCc1LEB4M -NrzCB2+KZ+SyOdfCCA5VzEKN3I8wvVLaYovi24Zjwv+0uETG92TlZmLQRhj8uPxN -deeLM45aBkQZSYCbGMDVDK/XYKBkNLn3kxD/eZeXxxr41v4pH44+46FkYcYJzdn8 -ccAg5LRGieqTozhLiXARNK1vTy6kR1l/Az8DIx6GN4sP2/LMFYFijiiOCDKS1wWA -xQgZeHt4GIuBym+Kd+Z5KXcP0AT+47Cby3+B10Kq8vHwjTELiF0UFeEYYMdynPAW -pbEwVLhsfMsBqFtj3dsxHr8Kz3rnarOYzkaw7EMZnLAthb2CN7y5uGV9imQC5RMI -/qZdRSuCYZ3A1E/WJkGbPY/YdPql/IE+LIAgKGFHZZNftBCo +-----BEGIN CERTIFICATE----- +MIIEcDCCAtigAwIBAgIRAPmUKRkyFAkVVxFblB/233cwDQYJKoZIhvcNAQELBQAw +gZ8xHjAcBgNVBAoTFW1rY2VydCBkZXZlbG9wbWVudCBDQTE6MDgGA1UECwwxZnJv +bWJlcmdlckBzdGFyZHVzdC5sb2NhbCAoTWljaGFlbCBKLiBGcm9tYmVyZ2VyKTFB +MD8GA1UEAww4bWtjZXJ0IGZyb21iZXJnZXJAc3RhcmR1c3QubG9jYWwgKE1pY2hh +ZWwgSi4gRnJvbWJlcmdlcikwHhcNMjMwMjA3MjAzNDE4WhcNMjUwNTA3MTkzNDE4 +WjBlMScwJQYDVQQKEx5ta2NlcnQgZGV2ZWxvcG1lbnQgY2VydGlmaWNhdGUxOjA4 +BgNVBAsMMWZyb21iZXJnZXJAc3RhcmR1c3QubG9jYWwgKE1pY2hhZWwgSi4gRnJv +bWJlcmdlcikwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCejQaJrntr +JSgEQtScyTU6TXOU+v1FdFjrsyHFK5mjV1C5pVQxnLn93GRshtIrGOLLrd3Wv2TV +YZOXxH7f1ZLFbneDURCXbS+7nmsg+TLHRSRKfODbE3oYZj7NSJ163CCvwSJKTdmL +pXbnui9F04tyk0zxO4Wre4ukwf6xtse8G5zl2RJrueiVAiouTG/pJdIS08dGQa0G +M1n9Aesa+TerlZcpRZR6X402yQqa8q/QqbIuzrlfDmgOb8sm6T8+JMtj3hEvnYdp +MVOgw/XiTlX0v/YrB9sVQ9XnqGsqwTL0OMG0choMNKipwLi2n+XPSCIiRhi666zN +NivEK1qaPS5RAgMBAAGjYDBeMA4GA1UdDwEB/wQEAwIFoDATBgNVHSUEDDAKBggr +BgEFBQcDATAfBgNVHSMEGDAWgBTXyq2jQVrnqQKL8fB9C4L0QJftwDAWBgNVHREE +DzANggtleGFtcGxlLmNvbTANBgkqhkiG9w0BAQsFAAOCAYEAQWzpOaBkRR4M+WqB +CsT4ARyM6WpZ+jpeSblCzPdlDRW+50G1HV7K930zayq4DwncPY/SqSn0Q31WuzZv +bTWHkWa+MLPGYANHsusOmMR8Eh16G4+5+GGf8psWa0npAYO35cuNkyyCCc1LEB4M +NrzCB2+KZ+SyOdfCCA5VzEKN3I8wvVLaYovi24Zjwv+0uETG92TlZmLQRhj8uPxN +deeLM45aBkQZSYCbGMDVDK/XYKBkNLn3kxD/eZeXxxr41v4pH44+46FkYcYJzdn8 +ccAg5LRGieqTozhLiXARNK1vTy6kR1l/Az8DIx6GN4sP2/LMFYFijiiOCDKS1wWA +xQgZeHt4GIuBym+Kd+Z5KXcP0AT+47Cby3+B10Kq8vHwjTELiF0UFeEYYMdynPAW +pbEwVLhsfMsBqFtj3dsxHr8Kz3rnarOYzkaw7EMZnLAthb2CN7y5uGV9imQC5RMI +/qZdRSuCYZ3A1E/WJkGbPY/YdPql/IE+LIAgKGFHZZNftBCo -----END CERTIFICATE----- \ No newline at end of file diff --git a/ipn/ipnlocal/testdata/rootCA.pem b/ipn/ipnlocal/testdata/rootCA.pem index 28bd25467..88a16f47a 100644 --- a/ipn/ipnlocal/testdata/rootCA.pem +++ b/ipn/ipnlocal/testdata/rootCA.pem @@ -1,30 +1,30 @@ ------BEGIN CERTIFICATE----- -MIIFEDCCA3igAwIBAgIRANf5NdPojIfj70wMfJVYUg8wDQYJKoZIhvcNAQELBQAw -gZ8xHjAcBgNVBAoTFW1rY2VydCBkZXZlbG9wbWVudCBDQTE6MDgGA1UECwwxZnJv -bWJlcmdlckBzdGFyZHVzdC5sb2NhbCAoTWljaGFlbCBKLiBGcm9tYmVyZ2VyKTFB -MD8GA1UEAww4bWtjZXJ0IGZyb21iZXJnZXJAc3RhcmR1c3QubG9jYWwgKE1pY2hh -ZWwgSi4gRnJvbWJlcmdlcikwHhcNMjMwMjA3MjAzNDE4WhcNMzMwMjA3MjAzNDE4 -WjCBnzEeMBwGA1UEChMVbWtjZXJ0IGRldmVsb3BtZW50IENBMTowOAYDVQQLDDFm -cm9tYmVyZ2VyQHN0YXJkdXN0LmxvY2FsIChNaWNoYWVsIEouIEZyb21iZXJnZXIp -MUEwPwYDVQQDDDhta2NlcnQgZnJvbWJlcmdlckBzdGFyZHVzdC5sb2NhbCAoTWlj -aGFlbCBKLiBGcm9tYmVyZ2VyKTCCAaIwDQYJKoZIhvcNAQEBBQADggGPADCCAYoC -ggGBAL5uXNnrZ6dgjcvK0Hc7ZNUIRYEWst9qbO0P9H7le08pJ6d9T2BUWruZtVjk -Q12msv5/bVWHhVk8dZclI9FLXuMsIrocH8bsoP4wruPMyRyp6EedSKODN51fFSRv -/jHbS5vzUVAWTYy9qYmd6qL0uhsHCZCCT6gfigamHPUFKM3sHDn5ZHWvySMwcyGl -AicmPAIkBWqiCZAkB5+WM7+oyRLjmrIalfWIZYxW/rojGLwTfneHv6J5WjVQnpJB -ayWCzCzaiXukK9MeBWeTOe8UfVN0Engd74/rjLWvjbfC+uZSr6RVkZvs2jANLwPF -zgzBPHgRPfAhszU1NNAMjnNQ47+OMOTKRt7e6jYzhO5fyO1qVAAvGBqcfpj+JfDk -cccaUMhUvdiGrhGf1V1tN/PislxvALirzcFipjD01isBKwn0fxRugzvJNrjEo8RA -RvbcdeKcwex7M0o/Cd0+G2B13gZNOFvR33PmG7iTpp7IUrUKfQg28I83Sp8tMY3s -ljJSawIDAQABo0UwQzAOBgNVHQ8BAf8EBAMCAgQwEgYDVR0TAQH/BAgwBgEB/wIB -ADAdBgNVHQ4EFgQU18qto0Fa56kCi/HwfQuC9ECX7cAwDQYJKoZIhvcNAQELBQAD -ggGBAAzs96LwZVOsRSlBdQqMo8oMAvs7HgnYbXt8SqaACLX3+kJ3cV/vrCE3iJrW -ma4CiQbxS/HqsiZjota5m4lYeEevRnUDpXhp+7ugZTiz33Flm1RU99c9UYfQ+919 -ANPAKeqNpoPco/HF5Bz0ocepjcfKQrVZZNTj6noLs8o12FHBLO5976AcF9mqlNfh -8/F0gDJXq6+x7VT5y8u0rY004XKPRe3CklRt8kpeMiP6mhRyyUehOaHeIbNx8ubi -Pi44ByN/ueAnuRhF9zYtyZVZZOaSLysJge01tuPXF8rBXGruoJIv35xTTBa9BzaP -YDOGbGn1ZnajdNagHqCba8vjTLDSpqMvgRj3TFrGHdETA2LDQat38uVxX8gxm68K -va5Tyv7n+6BQ5YTpJjTPnmSJKaXZrrhdLPvG0OU2TxeEsvbcm5LFQofirOOw86Se -vzF2cQ94mmHRZiEk0Av3NO0jF93ELDrBCuiccVyEKq6TknuvPQlutCXKDOYSEb8I -MHctBg== +-----BEGIN CERTIFICATE----- +MIIFEDCCA3igAwIBAgIRANf5NdPojIfj70wMfJVYUg8wDQYJKoZIhvcNAQELBQAw +gZ8xHjAcBgNVBAoTFW1rY2VydCBkZXZlbG9wbWVudCBDQTE6MDgGA1UECwwxZnJv +bWJlcmdlckBzdGFyZHVzdC5sb2NhbCAoTWljaGFlbCBKLiBGcm9tYmVyZ2VyKTFB +MD8GA1UEAww4bWtjZXJ0IGZyb21iZXJnZXJAc3RhcmR1c3QubG9jYWwgKE1pY2hh +ZWwgSi4gRnJvbWJlcmdlcikwHhcNMjMwMjA3MjAzNDE4WhcNMzMwMjA3MjAzNDE4 +WjCBnzEeMBwGA1UEChMVbWtjZXJ0IGRldmVsb3BtZW50IENBMTowOAYDVQQLDDFm +cm9tYmVyZ2VyQHN0YXJkdXN0LmxvY2FsIChNaWNoYWVsIEouIEZyb21iZXJnZXIp +MUEwPwYDVQQDDDhta2NlcnQgZnJvbWJlcmdlckBzdGFyZHVzdC5sb2NhbCAoTWlj +aGFlbCBKLiBGcm9tYmVyZ2VyKTCCAaIwDQYJKoZIhvcNAQEBBQADggGPADCCAYoC +ggGBAL5uXNnrZ6dgjcvK0Hc7ZNUIRYEWst9qbO0P9H7le08pJ6d9T2BUWruZtVjk +Q12msv5/bVWHhVk8dZclI9FLXuMsIrocH8bsoP4wruPMyRyp6EedSKODN51fFSRv +/jHbS5vzUVAWTYy9qYmd6qL0uhsHCZCCT6gfigamHPUFKM3sHDn5ZHWvySMwcyGl +AicmPAIkBWqiCZAkB5+WM7+oyRLjmrIalfWIZYxW/rojGLwTfneHv6J5WjVQnpJB +ayWCzCzaiXukK9MeBWeTOe8UfVN0Engd74/rjLWvjbfC+uZSr6RVkZvs2jANLwPF +zgzBPHgRPfAhszU1NNAMjnNQ47+OMOTKRt7e6jYzhO5fyO1qVAAvGBqcfpj+JfDk +cccaUMhUvdiGrhGf1V1tN/PislxvALirzcFipjD01isBKwn0fxRugzvJNrjEo8RA +RvbcdeKcwex7M0o/Cd0+G2B13gZNOFvR33PmG7iTpp7IUrUKfQg28I83Sp8tMY3s +ljJSawIDAQABo0UwQzAOBgNVHQ8BAf8EBAMCAgQwEgYDVR0TAQH/BAgwBgEB/wIB +ADAdBgNVHQ4EFgQU18qto0Fa56kCi/HwfQuC9ECX7cAwDQYJKoZIhvcNAQELBQAD +ggGBAAzs96LwZVOsRSlBdQqMo8oMAvs7HgnYbXt8SqaACLX3+kJ3cV/vrCE3iJrW +ma4CiQbxS/HqsiZjota5m4lYeEevRnUDpXhp+7ugZTiz33Flm1RU99c9UYfQ+919 +ANPAKeqNpoPco/HF5Bz0ocepjcfKQrVZZNTj6noLs8o12FHBLO5976AcF9mqlNfh +8/F0gDJXq6+x7VT5y8u0rY004XKPRe3CklRt8kpeMiP6mhRyyUehOaHeIbNx8ubi +Pi44ByN/ueAnuRhF9zYtyZVZZOaSLysJge01tuPXF8rBXGruoJIv35xTTBa9BzaP +YDOGbGn1ZnajdNagHqCba8vjTLDSpqMvgRj3TFrGHdETA2LDQat38uVxX8gxm68K +va5Tyv7n+6BQ5YTpJjTPnmSJKaXZrrhdLPvG0OU2TxeEsvbcm5LFQofirOOw86Se +vzF2cQ94mmHRZiEk0Av3NO0jF93ELDrBCuiccVyEKq6TknuvPQlutCXKDOYSEb8I +MHctBg== -----END CERTIFICATE----- \ No newline at end of file diff --git a/ipn/ipnserver/proxyconnect_js.go b/ipn/ipnserver/proxyconnect_js.go index 27448fa0d..368221e22 100644 --- a/ipn/ipnserver/proxyconnect_js.go +++ b/ipn/ipnserver/proxyconnect_js.go @@ -1,10 +1,10 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package ipnserver - -import "net/http" - -func (s *Server) handleProxyConnectConn(w http.ResponseWriter, r *http.Request) { - panic("unreachable") -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnserver + +import "net/http" + +func (s *Server) handleProxyConnectConn(w http.ResponseWriter, r *http.Request) { + panic("unreachable") +} diff --git a/ipn/ipnserver/server_test.go b/ipn/ipnserver/server_test.go index 49fb4d01f..b7d5ea144 100644 --- a/ipn/ipnserver/server_test.go +++ b/ipn/ipnserver/server_test.go @@ -1,46 +1,46 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package ipnserver - -import ( - "context" - "sync" - "testing" -) - -func TestWaiterSet(t *testing.T) { - var s waiterSet - - wantLen := func(want int, when string) { - t.Helper() - if got := len(s); got != want { - t.Errorf("%s: len = %v; want %v", when, got, want) - } - } - wantLen(0, "initial") - var mu sync.Mutex - ctx, cancel := context.WithCancel(context.Background()) - - ready, cleanup := s.add(&mu, ctx) - wantLen(1, "after add") - - select { - case <-ready: - t.Fatal("should not be ready") - default: - } - s.wakeAll() - <-ready - - wantLen(1, "after fire") - cleanup() - wantLen(0, "after cleanup") - - // And again but on an already-expired ctx. - cancel() - ready, cleanup = s.add(&mu, ctx) - <-ready // shouldn't block - cleanup() - wantLen(0, "at end") -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnserver + +import ( + "context" + "sync" + "testing" +) + +func TestWaiterSet(t *testing.T) { + var s waiterSet + + wantLen := func(want int, when string) { + t.Helper() + if got := len(s); got != want { + t.Errorf("%s: len = %v; want %v", when, got, want) + } + } + wantLen(0, "initial") + var mu sync.Mutex + ctx, cancel := context.WithCancel(context.Background()) + + ready, cleanup := s.add(&mu, ctx) + wantLen(1, "after add") + + select { + case <-ready: + t.Fatal("should not be ready") + default: + } + s.wakeAll() + <-ready + + wantLen(1, "after fire") + cleanup() + wantLen(0, "after cleanup") + + // And again but on an already-expired ctx. + cancel() + ready, cleanup = s.add(&mu, ctx) + <-ready // shouldn't block + cleanup() + wantLen(0, "at end") +} diff --git a/ipn/localapi/disabled_stubs.go b/ipn/localapi/disabled_stubs.go index 230553c14..c744f34d5 100644 --- a/ipn/localapi/disabled_stubs.go +++ b/ipn/localapi/disabled_stubs.go @@ -1,15 +1,15 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ios || android || js - -package localapi - -import ( - "net/http" - "runtime" -) - -func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) { - http.Error(w, "disabled on "+runtime.GOOS, http.StatusNotFound) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ios || android || js + +package localapi + +import ( + "net/http" + "runtime" +) + +func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) { + http.Error(w, "disabled on "+runtime.GOOS, http.StatusNotFound) +} diff --git a/ipn/localapi/pprof.go b/ipn/localapi/pprof.go index 5cc4daca1..8c9429b31 100644 --- a/ipn/localapi/pprof.go +++ b/ipn/localapi/pprof.go @@ -1,28 +1,28 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !ios && !android && !js - -// We don't include it on mobile where we're more memory constrained and -// there's no CLI to get at the results anyway. - -package localapi - -import ( - "net/http" - "net/http/pprof" -) - -func init() { - servePprofFunc = servePprof -} - -func servePprof(w http.ResponseWriter, r *http.Request) { - name := r.FormValue("name") - switch name { - case "profile": - pprof.Profile(w, r) - default: - pprof.Handler(name).ServeHTTP(w, r) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ios && !android && !js + +// We don't include it on mobile where we're more memory constrained and +// there's no CLI to get at the results anyway. + +package localapi + +import ( + "net/http" + "net/http/pprof" +) + +func init() { + servePprofFunc = servePprof +} + +func servePprof(w http.ResponseWriter, r *http.Request) { + name := r.FormValue("name") + switch name { + case "profile": + pprof.Profile(w, r) + default: + pprof.Handler(name).ServeHTTP(w, r) + } +} diff --git a/ipn/policy/policy.go b/ipn/policy/policy.go index 834706f31..494a0dc40 100644 --- a/ipn/policy/policy.go +++ b/ipn/policy/policy.go @@ -1,47 +1,47 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package policy contains various policy decisions that need to be -// shared between the node client & control server. -package policy - -import ( - "tailscale.com/tailcfg" -) - -// IsInterestingService reports whether service s on the given operating -// system (a version.OS value) is an interesting enough port to report -// to our peer nodes for discovery purposes. -func IsInterestingService(s tailcfg.Service, os string) bool { - switch s.Proto { - case tailcfg.PeerAPI4, tailcfg.PeerAPI6, tailcfg.PeerAPIDNS: - return true - } - if s.Proto != tailcfg.TCP { - return false - } - if os != "windows" { - // For non-Windows machines, assume all TCP listeners - // are interesting enough. We don't see listener spam - // there. - return true - } - // Windows has tons of TCP listeners. We need to move to a denylist - // model later, but for now we just allow some common ones: - switch s.Port { - case 22, // ssh - 80, // http - 443, // https (but no hostname, so little useless) - 3389, // rdp - 5900, // vnc - 32400, // plex - - // And now some arbitrary HTTP dev server ports: - // Eventually we'll remove this and make all ports - // work, once we nicely filter away noisy system - // ports. - 8000, 8080, 8443, 8888: - return true - } - return false -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package policy contains various policy decisions that need to be +// shared between the node client & control server. +package policy + +import ( + "tailscale.com/tailcfg" +) + +// IsInterestingService reports whether service s on the given operating +// system (a version.OS value) is an interesting enough port to report +// to our peer nodes for discovery purposes. +func IsInterestingService(s tailcfg.Service, os string) bool { + switch s.Proto { + case tailcfg.PeerAPI4, tailcfg.PeerAPI6, tailcfg.PeerAPIDNS: + return true + } + if s.Proto != tailcfg.TCP { + return false + } + if os != "windows" { + // For non-Windows machines, assume all TCP listeners + // are interesting enough. We don't see listener spam + // there. + return true + } + // Windows has tons of TCP listeners. We need to move to a denylist + // model later, but for now we just allow some common ones: + switch s.Port { + case 22, // ssh + 80, // http + 443, // https (but no hostname, so little useless) + 3389, // rdp + 5900, // vnc + 32400, // plex + + // And now some arbitrary HTTP dev server ports: + // Eventually we'll remove this and make all ports + // work, once we nicely filter away noisy system + // ports. + 8000, 8080, 8443, 8888: + return true + } + return false +} diff --git a/ipn/store/awsstore/store_aws.go b/ipn/store/awsstore/store_aws.go index 84059af67..0fb78d45a 100644 --- a/ipn/store/awsstore/store_aws.go +++ b/ipn/store/awsstore/store_aws.go @@ -1,186 +1,186 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux && !ts_omit_aws - -// Package awsstore contains an ipn.StateStore implementation using AWS SSM. -package awsstore - -import ( - "context" - "errors" - "fmt" - "regexp" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/aws/arn" - "github.com/aws/aws-sdk-go-v2/config" - "github.com/aws/aws-sdk-go-v2/service/ssm" - ssmTypes "github.com/aws/aws-sdk-go-v2/service/ssm/types" - "tailscale.com/ipn" - "tailscale.com/ipn/store/mem" - "tailscale.com/types/logger" -) - -const ( - parameterNameRxStr = `^parameter(/.*)` -) - -var parameterNameRx = regexp.MustCompile(parameterNameRxStr) - -// awsSSMClient is an interface allowing us to mock the couple of -// API calls we are leveraging with the AWSStore provider -type awsSSMClient interface { - GetParameter(ctx context.Context, - params *ssm.GetParameterInput, - optFns ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) - - PutParameter(ctx context.Context, - params *ssm.PutParameterInput, - optFns ...func(*ssm.Options)) (*ssm.PutParameterOutput, error) -} - -// store is a store which leverages AWS SSM parameter store -// to persist the state -type awsStore struct { - ssmClient awsSSMClient - ssmARN arn.ARN - - memory mem.Store -} - -// New returns a new ipn.StateStore using the AWS SSM storage -// location given by ssmARN. -// -// Note that we store the entire store in a single parameter -// key, therefore if the state is above 8kb, it can cause -// Tailscaled to only only store new state in-memory and -// restarting Tailscaled can fail until you delete your state -// from the AWS Parameter Store. -func New(_ logger.Logf, ssmARN string) (ipn.StateStore, error) { - return newStore(ssmARN, nil) -} - -// newStore is NewStore, but for tests. If client is non-nil, it's -// used instead of making one. -func newStore(ssmARN string, client awsSSMClient) (ipn.StateStore, error) { - s := &awsStore{ - ssmClient: client, - } - - var err error - - // Parse the ARN - if s.ssmARN, err = arn.Parse(ssmARN); err != nil { - return nil, fmt.Errorf("unable to parse the ARN correctly: %v", err) - } - - // Validate the ARN corresponds to the SSM service - if s.ssmARN.Service != "ssm" { - return nil, fmt.Errorf("invalid service %q, expected 'ssm'", s.ssmARN.Service) - } - - // Validate the ARN corresponds to a parameter store resource - if !parameterNameRx.MatchString(s.ssmARN.Resource) { - return nil, fmt.Errorf("invalid resource %q, expected to match %v", s.ssmARN.Resource, parameterNameRxStr) - } - - if s.ssmClient == nil { - var cfg aws.Config - if cfg, err = config.LoadDefaultConfig( - context.TODO(), - config.WithRegion(s.ssmARN.Region), - ); err != nil { - return nil, err - } - s.ssmClient = ssm.NewFromConfig(cfg) - } - - // Hydrate cache with the potentially current state - if err := s.LoadState(); err != nil { - return nil, err - } - return s, nil - -} - -// LoadState attempts to read the state from AWS SSM parameter store key. -func (s *awsStore) LoadState() error { - param, err := s.ssmClient.GetParameter( - context.TODO(), - &ssm.GetParameterInput{ - Name: aws.String(s.ParameterName()), - WithDecryption: aws.Bool(true), - }, - ) - - if err != nil { - var pnf *ssmTypes.ParameterNotFound - if errors.As(err, &pnf) { - // Create the parameter as it does not exist yet - // and return directly as it is defacto empty - return s.persistState() - } - return err - } - - // Load the content in-memory - return s.memory.LoadFromJSON([]byte(*param.Parameter.Value)) -} - -// ParameterName returns the parameter name extracted from -// the provided ARN -func (s *awsStore) ParameterName() (name string) { - values := parameterNameRx.FindStringSubmatch(s.ssmARN.Resource) - if len(values) == 2 { - name = values[1] - } - return -} - -// String returns the awsStore and the ARN of the SSM parameter store -// configured to store the state -func (s *awsStore) String() string { return fmt.Sprintf("awsStore(%q)", s.ssmARN.String()) } - -// ReadState implements the Store interface. -func (s *awsStore) ReadState(id ipn.StateKey) (bs []byte, err error) { - return s.memory.ReadState(id) -} - -// WriteState implements the Store interface. -func (s *awsStore) WriteState(id ipn.StateKey, bs []byte) (err error) { - // Write the state in-memory - if err = s.memory.WriteState(id, bs); err != nil { - return - } - - // Persist the state in AWS SSM parameter store - return s.persistState() -} - -// PersistState saves the states into the AWS SSM parameter store -func (s *awsStore) persistState() error { - // Generate JSON from in-memory cache - bs, err := s.memory.ExportToJSON() - if err != nil { - return err - } - - // Store in AWS SSM parameter store. - // - // We use intelligent tiering so that when the state is below 4kb, it uses Standard tiering - // which is free. However, if it exceeds 4kb it switches the parameter to advanced tiering - // doubling the capacity to 8kb per the following docs: - // https://aws.amazon.com/about-aws/whats-new/2019/08/aws-systems-manager-parameter-store-announces-intelligent-tiering-to-enable-automatic-parameter-tier-selection/ - _, err = s.ssmClient.PutParameter( - context.TODO(), - &ssm.PutParameterInput{ - Name: aws.String(s.ParameterName()), - Value: aws.String(string(bs)), - Overwrite: aws.Bool(true), - Tier: ssmTypes.ParameterTierIntelligentTiering, - Type: ssmTypes.ParameterTypeSecureString, - }, - ) - return err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux && !ts_omit_aws + +// Package awsstore contains an ipn.StateStore implementation using AWS SSM. +package awsstore + +import ( + "context" + "errors" + "fmt" + "regexp" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/arn" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/ssm" + ssmTypes "github.com/aws/aws-sdk-go-v2/service/ssm/types" + "tailscale.com/ipn" + "tailscale.com/ipn/store/mem" + "tailscale.com/types/logger" +) + +const ( + parameterNameRxStr = `^parameter(/.*)` +) + +var parameterNameRx = regexp.MustCompile(parameterNameRxStr) + +// awsSSMClient is an interface allowing us to mock the couple of +// API calls we are leveraging with the AWSStore provider +type awsSSMClient interface { + GetParameter(ctx context.Context, + params *ssm.GetParameterInput, + optFns ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) + + PutParameter(ctx context.Context, + params *ssm.PutParameterInput, + optFns ...func(*ssm.Options)) (*ssm.PutParameterOutput, error) +} + +// store is a store which leverages AWS SSM parameter store +// to persist the state +type awsStore struct { + ssmClient awsSSMClient + ssmARN arn.ARN + + memory mem.Store +} + +// New returns a new ipn.StateStore using the AWS SSM storage +// location given by ssmARN. +// +// Note that we store the entire store in a single parameter +// key, therefore if the state is above 8kb, it can cause +// Tailscaled to only only store new state in-memory and +// restarting Tailscaled can fail until you delete your state +// from the AWS Parameter Store. +func New(_ logger.Logf, ssmARN string) (ipn.StateStore, error) { + return newStore(ssmARN, nil) +} + +// newStore is NewStore, but for tests. If client is non-nil, it's +// used instead of making one. +func newStore(ssmARN string, client awsSSMClient) (ipn.StateStore, error) { + s := &awsStore{ + ssmClient: client, + } + + var err error + + // Parse the ARN + if s.ssmARN, err = arn.Parse(ssmARN); err != nil { + return nil, fmt.Errorf("unable to parse the ARN correctly: %v", err) + } + + // Validate the ARN corresponds to the SSM service + if s.ssmARN.Service != "ssm" { + return nil, fmt.Errorf("invalid service %q, expected 'ssm'", s.ssmARN.Service) + } + + // Validate the ARN corresponds to a parameter store resource + if !parameterNameRx.MatchString(s.ssmARN.Resource) { + return nil, fmt.Errorf("invalid resource %q, expected to match %v", s.ssmARN.Resource, parameterNameRxStr) + } + + if s.ssmClient == nil { + var cfg aws.Config + if cfg, err = config.LoadDefaultConfig( + context.TODO(), + config.WithRegion(s.ssmARN.Region), + ); err != nil { + return nil, err + } + s.ssmClient = ssm.NewFromConfig(cfg) + } + + // Hydrate cache with the potentially current state + if err := s.LoadState(); err != nil { + return nil, err + } + return s, nil + +} + +// LoadState attempts to read the state from AWS SSM parameter store key. +func (s *awsStore) LoadState() error { + param, err := s.ssmClient.GetParameter( + context.TODO(), + &ssm.GetParameterInput{ + Name: aws.String(s.ParameterName()), + WithDecryption: aws.Bool(true), + }, + ) + + if err != nil { + var pnf *ssmTypes.ParameterNotFound + if errors.As(err, &pnf) { + // Create the parameter as it does not exist yet + // and return directly as it is defacto empty + return s.persistState() + } + return err + } + + // Load the content in-memory + return s.memory.LoadFromJSON([]byte(*param.Parameter.Value)) +} + +// ParameterName returns the parameter name extracted from +// the provided ARN +func (s *awsStore) ParameterName() (name string) { + values := parameterNameRx.FindStringSubmatch(s.ssmARN.Resource) + if len(values) == 2 { + name = values[1] + } + return +} + +// String returns the awsStore and the ARN of the SSM parameter store +// configured to store the state +func (s *awsStore) String() string { return fmt.Sprintf("awsStore(%q)", s.ssmARN.String()) } + +// ReadState implements the Store interface. +func (s *awsStore) ReadState(id ipn.StateKey) (bs []byte, err error) { + return s.memory.ReadState(id) +} + +// WriteState implements the Store interface. +func (s *awsStore) WriteState(id ipn.StateKey, bs []byte) (err error) { + // Write the state in-memory + if err = s.memory.WriteState(id, bs); err != nil { + return + } + + // Persist the state in AWS SSM parameter store + return s.persistState() +} + +// PersistState saves the states into the AWS SSM parameter store +func (s *awsStore) persistState() error { + // Generate JSON from in-memory cache + bs, err := s.memory.ExportToJSON() + if err != nil { + return err + } + + // Store in AWS SSM parameter store. + // + // We use intelligent tiering so that when the state is below 4kb, it uses Standard tiering + // which is free. However, if it exceeds 4kb it switches the parameter to advanced tiering + // doubling the capacity to 8kb per the following docs: + // https://aws.amazon.com/about-aws/whats-new/2019/08/aws-systems-manager-parameter-store-announces-intelligent-tiering-to-enable-automatic-parameter-tier-selection/ + _, err = s.ssmClient.PutParameter( + context.TODO(), + &ssm.PutParameterInput{ + Name: aws.String(s.ParameterName()), + Value: aws.String(string(bs)), + Overwrite: aws.Bool(true), + Tier: ssmTypes.ParameterTierIntelligentTiering, + Type: ssmTypes.ParameterTypeSecureString, + }, + ) + return err +} diff --git a/ipn/store/awsstore/store_aws_stub.go b/ipn/store/awsstore/store_aws_stub.go index 7be8b858d..8d2156ce9 100644 --- a/ipn/store/awsstore/store_aws_stub.go +++ b/ipn/store/awsstore/store_aws_stub.go @@ -1,18 +1,18 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux || ts_omit_aws - -package awsstore - -import ( - "fmt" - "runtime" - - "tailscale.com/ipn" - "tailscale.com/types/logger" -) - -func New(logger.Logf, string) (ipn.StateStore, error) { - return nil, fmt.Errorf("AWS store is not supported on %v", runtime.GOOS) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux || ts_omit_aws + +package awsstore + +import ( + "fmt" + "runtime" + + "tailscale.com/ipn" + "tailscale.com/types/logger" +) + +func New(logger.Logf, string) (ipn.StateStore, error) { + return nil, fmt.Errorf("AWS store is not supported on %v", runtime.GOOS) +} diff --git a/ipn/store/awsstore/store_aws_test.go b/ipn/store/awsstore/store_aws_test.go index 54e6e18cb..f6c8fedb3 100644 --- a/ipn/store/awsstore/store_aws_test.go +++ b/ipn/store/awsstore/store_aws_test.go @@ -1,164 +1,164 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux - -package awsstore - -import ( - "context" - "testing" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/aws/arn" - "github.com/aws/aws-sdk-go-v2/service/ssm" - ssmTypes "github.com/aws/aws-sdk-go-v2/service/ssm/types" - "tailscale.com/ipn" - "tailscale.com/tstest" -) - -type mockedAWSSSMClient struct { - value string -} - -func (sp *mockedAWSSSMClient) GetParameter(_ context.Context, input *ssm.GetParameterInput, _ ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) { - output := new(ssm.GetParameterOutput) - if sp.value == "" { - return output, &ssmTypes.ParameterNotFound{} - } - - output.Parameter = &ssmTypes.Parameter{ - Value: aws.String(sp.value), - } - - return output, nil -} - -func (sp *mockedAWSSSMClient) PutParameter(_ context.Context, input *ssm.PutParameterInput, _ ...func(*ssm.Options)) (*ssm.PutParameterOutput, error) { - sp.value = *input.Value - return new(ssm.PutParameterOutput), nil -} - -func TestAWSStoreString(t *testing.T) { - store := &awsStore{ - ssmARN: arn.ARN{ - Service: "ssm", - Region: "eu-west-1", - AccountID: "123456789", - Resource: "parameter/foo", - }, - } - want := "awsStore(\"arn::ssm:eu-west-1:123456789:parameter/foo\")" - if got := store.String(); got != want { - t.Errorf("AWSStore.String = %q; want %q", got, want) - } -} - -func TestNewAWSStore(t *testing.T) { - tstest.PanicOnLog() - - mc := &mockedAWSSSMClient{} - storeParameterARN := arn.ARN{ - Service: "ssm", - Region: "eu-west-1", - AccountID: "123456789", - Resource: "parameter/foo", - } - - s, err := newStore(storeParameterARN.String(), mc) - if err != nil { - t.Fatalf("creating aws store failed: %v", err) - } - testStoreSemantics(t, s) - - // Build a brand new file store and check that both IDs written - // above are still there. - s2, err := newStore(storeParameterARN.String(), mc) - if err != nil { - t.Fatalf("creating second aws store failed: %v", err) - } - store2 := s.(*awsStore) - - // This is specific to the test, with the non-mocked API, LoadState() should - // have been already called and successful as no err is returned from NewAWSStore() - s2.(*awsStore).LoadState() - - expected := map[ipn.StateKey]string{ - "foo": "bar", - "baz": "quux", - } - for id, want := range expected { - bs, err := store2.ReadState(id) - if err != nil { - t.Errorf("reading %q (2nd store): %v", id, err) - } - if string(bs) != want { - t.Errorf("reading %q (2nd store): got %q, want %q", id, string(bs), want) - } - } -} - -func testStoreSemantics(t *testing.T, store ipn.StateStore) { - t.Helper() - - tests := []struct { - // if true, data is data to write. If false, data is expected - // output of read. - write bool - id ipn.StateKey - data string - // If write=false, true if we expect a not-exist error. - notExists bool - }{ - { - id: "foo", - notExists: true, - }, - { - write: true, - id: "foo", - data: "bar", - }, - { - id: "foo", - data: "bar", - }, - { - id: "baz", - notExists: true, - }, - { - write: true, - id: "baz", - data: "quux", - }, - { - id: "foo", - data: "bar", - }, - { - id: "baz", - data: "quux", - }, - } - - for _, test := range tests { - if test.write { - if err := store.WriteState(test.id, []byte(test.data)); err != nil { - t.Errorf("writing %q to %q: %v", test.data, test.id, err) - } - } else { - bs, err := store.ReadState(test.id) - if err != nil { - if test.notExists && err == ipn.ErrStateNotExist { - continue - } - t.Errorf("reading %q: %v", test.id, err) - continue - } - if string(bs) != test.data { - t.Errorf("reading %q: got %q, want %q", test.id, string(bs), test.data) - } - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package awsstore + +import ( + "context" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/arn" + "github.com/aws/aws-sdk-go-v2/service/ssm" + ssmTypes "github.com/aws/aws-sdk-go-v2/service/ssm/types" + "tailscale.com/ipn" + "tailscale.com/tstest" +) + +type mockedAWSSSMClient struct { + value string +} + +func (sp *mockedAWSSSMClient) GetParameter(_ context.Context, input *ssm.GetParameterInput, _ ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) { + output := new(ssm.GetParameterOutput) + if sp.value == "" { + return output, &ssmTypes.ParameterNotFound{} + } + + output.Parameter = &ssmTypes.Parameter{ + Value: aws.String(sp.value), + } + + return output, nil +} + +func (sp *mockedAWSSSMClient) PutParameter(_ context.Context, input *ssm.PutParameterInput, _ ...func(*ssm.Options)) (*ssm.PutParameterOutput, error) { + sp.value = *input.Value + return new(ssm.PutParameterOutput), nil +} + +func TestAWSStoreString(t *testing.T) { + store := &awsStore{ + ssmARN: arn.ARN{ + Service: "ssm", + Region: "eu-west-1", + AccountID: "123456789", + Resource: "parameter/foo", + }, + } + want := "awsStore(\"arn::ssm:eu-west-1:123456789:parameter/foo\")" + if got := store.String(); got != want { + t.Errorf("AWSStore.String = %q; want %q", got, want) + } +} + +func TestNewAWSStore(t *testing.T) { + tstest.PanicOnLog() + + mc := &mockedAWSSSMClient{} + storeParameterARN := arn.ARN{ + Service: "ssm", + Region: "eu-west-1", + AccountID: "123456789", + Resource: "parameter/foo", + } + + s, err := newStore(storeParameterARN.String(), mc) + if err != nil { + t.Fatalf("creating aws store failed: %v", err) + } + testStoreSemantics(t, s) + + // Build a brand new file store and check that both IDs written + // above are still there. + s2, err := newStore(storeParameterARN.String(), mc) + if err != nil { + t.Fatalf("creating second aws store failed: %v", err) + } + store2 := s.(*awsStore) + + // This is specific to the test, with the non-mocked API, LoadState() should + // have been already called and successful as no err is returned from NewAWSStore() + s2.(*awsStore).LoadState() + + expected := map[ipn.StateKey]string{ + "foo": "bar", + "baz": "quux", + } + for id, want := range expected { + bs, err := store2.ReadState(id) + if err != nil { + t.Errorf("reading %q (2nd store): %v", id, err) + } + if string(bs) != want { + t.Errorf("reading %q (2nd store): got %q, want %q", id, string(bs), want) + } + } +} + +func testStoreSemantics(t *testing.T, store ipn.StateStore) { + t.Helper() + + tests := []struct { + // if true, data is data to write. If false, data is expected + // output of read. + write bool + id ipn.StateKey + data string + // If write=false, true if we expect a not-exist error. + notExists bool + }{ + { + id: "foo", + notExists: true, + }, + { + write: true, + id: "foo", + data: "bar", + }, + { + id: "foo", + data: "bar", + }, + { + id: "baz", + notExists: true, + }, + { + write: true, + id: "baz", + data: "quux", + }, + { + id: "foo", + data: "bar", + }, + { + id: "baz", + data: "quux", + }, + } + + for _, test := range tests { + if test.write { + if err := store.WriteState(test.id, []byte(test.data)); err != nil { + t.Errorf("writing %q to %q: %v", test.data, test.id, err) + } + } else { + bs, err := store.ReadState(test.id) + if err != nil { + if test.notExists && err == ipn.ErrStateNotExist { + continue + } + t.Errorf("reading %q: %v", test.id, err) + continue + } + if string(bs) != test.data { + t.Errorf("reading %q: got %q, want %q", test.id, string(bs), test.data) + } + } + } +} diff --git a/ipn/store/stores_test.go b/ipn/store/stores_test.go index 69aa79193..ea09e6ea6 100644 --- a/ipn/store/stores_test.go +++ b/ipn/store/stores_test.go @@ -1,179 +1,179 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package store - -import ( - "path/filepath" - "testing" - - "tailscale.com/ipn" - "tailscale.com/ipn/store/mem" - "tailscale.com/tstest" - "tailscale.com/types/logger" -) - -func TestNewStore(t *testing.T) { - regOnce.Do(registerDefaultStores) - t.Cleanup(func() { - knownStores = map[string]Provider{} - registerDefaultStores() - }) - knownStores = map[string]Provider{} - - type store1 struct { - ipn.StateStore - path string - } - - type store2 struct { - ipn.StateStore - path string - } - - Register("arn:", func(_ logger.Logf, path string) (ipn.StateStore, error) { - return &store1{new(mem.Store), path}, nil - }) - Register("kube:", func(_ logger.Logf, path string) (ipn.StateStore, error) { - return &store2{new(mem.Store), path}, nil - }) - Register("mem:", func(_ logger.Logf, path string) (ipn.StateStore, error) { - return new(mem.Store), nil - }) - - path := "mem:abcd" - if s, err := New(t.Logf, path); err != nil { - t.Fatalf("%q: %v", path, err) - } else if _, ok := s.(*mem.Store); !ok { - t.Fatalf("%q: got: %T, want: %T", path, s, new(mem.Store)) - } - - path = "arn:foo" - if s, err := New(t.Logf, path); err != nil { - t.Fatalf("%q: %v", path, err) - } else if _, ok := s.(*store1); !ok { - t.Fatalf("%q: got: %T, want: %T", path, s, new(store1)) - } - - path = "kube:abcd" - if s, err := New(t.Logf, path); err != nil { - t.Fatalf("%q: %v", path, err) - } else if _, ok := s.(*store2); !ok { - t.Fatalf("%q: got: %T, want: %T", path, s, new(store2)) - } - - path = filepath.Join(t.TempDir(), "state") - if s, err := New(t.Logf, path); err != nil { - t.Fatalf("%q: %v", path, err) - } else if _, ok := s.(*FileStore); !ok { - t.Fatalf("%q: got: %T, want: %T", path, s, new(FileStore)) - } -} - -func testStoreSemantics(t *testing.T, store ipn.StateStore) { - t.Helper() - - tests := []struct { - // if true, data is data to write. If false, data is expected - // output of read. - write bool - id ipn.StateKey - data string - // If write=false, true if we expect a not-exist error. - notExists bool - }{ - { - id: "foo", - notExists: true, - }, - { - write: true, - id: "foo", - data: "bar", - }, - { - id: "foo", - data: "bar", - }, - { - id: "baz", - notExists: true, - }, - { - write: true, - id: "baz", - data: "quux", - }, - { - id: "foo", - data: "bar", - }, - { - id: "baz", - data: "quux", - }, - } - - for _, test := range tests { - if test.write { - if err := store.WriteState(test.id, []byte(test.data)); err != nil { - t.Errorf("writing %q to %q: %v", test.data, test.id, err) - } - } else { - bs, err := store.ReadState(test.id) - if err != nil { - if test.notExists && err == ipn.ErrStateNotExist { - continue - } - t.Errorf("reading %q: %v", test.id, err) - continue - } - if string(bs) != test.data { - t.Errorf("reading %q: got %q, want %q", test.id, string(bs), test.data) - } - } - } -} - -func TestMemoryStore(t *testing.T) { - tstest.PanicOnLog() - - store := new(mem.Store) - testStoreSemantics(t, store) -} - -func TestFileStore(t *testing.T) { - tstest.PanicOnLog() - - dir := t.TempDir() - path := filepath.Join(dir, "test-file-store.conf") - - store, err := NewFileStore(nil, path) - if err != nil { - t.Fatalf("creating file store failed: %v", err) - } - - testStoreSemantics(t, store) - - // Build a brand new file store and check that both IDs written - // above are still there. - store, err = NewFileStore(nil, path) - if err != nil { - t.Fatalf("creating second file store failed: %v", err) - } - - expected := map[ipn.StateKey]string{ - "foo": "bar", - "baz": "quux", - } - for key, want := range expected { - bs, err := store.ReadState(key) - if err != nil { - t.Errorf("reading %q (2nd store): %v", key, err) - continue - } - if string(bs) != want { - t.Errorf("reading %q (2nd store): got %q, want %q", key, bs, want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package store + +import ( + "path/filepath" + "testing" + + "tailscale.com/ipn" + "tailscale.com/ipn/store/mem" + "tailscale.com/tstest" + "tailscale.com/types/logger" +) + +func TestNewStore(t *testing.T) { + regOnce.Do(registerDefaultStores) + t.Cleanup(func() { + knownStores = map[string]Provider{} + registerDefaultStores() + }) + knownStores = map[string]Provider{} + + type store1 struct { + ipn.StateStore + path string + } + + type store2 struct { + ipn.StateStore + path string + } + + Register("arn:", func(_ logger.Logf, path string) (ipn.StateStore, error) { + return &store1{new(mem.Store), path}, nil + }) + Register("kube:", func(_ logger.Logf, path string) (ipn.StateStore, error) { + return &store2{new(mem.Store), path}, nil + }) + Register("mem:", func(_ logger.Logf, path string) (ipn.StateStore, error) { + return new(mem.Store), nil + }) + + path := "mem:abcd" + if s, err := New(t.Logf, path); err != nil { + t.Fatalf("%q: %v", path, err) + } else if _, ok := s.(*mem.Store); !ok { + t.Fatalf("%q: got: %T, want: %T", path, s, new(mem.Store)) + } + + path = "arn:foo" + if s, err := New(t.Logf, path); err != nil { + t.Fatalf("%q: %v", path, err) + } else if _, ok := s.(*store1); !ok { + t.Fatalf("%q: got: %T, want: %T", path, s, new(store1)) + } + + path = "kube:abcd" + if s, err := New(t.Logf, path); err != nil { + t.Fatalf("%q: %v", path, err) + } else if _, ok := s.(*store2); !ok { + t.Fatalf("%q: got: %T, want: %T", path, s, new(store2)) + } + + path = filepath.Join(t.TempDir(), "state") + if s, err := New(t.Logf, path); err != nil { + t.Fatalf("%q: %v", path, err) + } else if _, ok := s.(*FileStore); !ok { + t.Fatalf("%q: got: %T, want: %T", path, s, new(FileStore)) + } +} + +func testStoreSemantics(t *testing.T, store ipn.StateStore) { + t.Helper() + + tests := []struct { + // if true, data is data to write. If false, data is expected + // output of read. + write bool + id ipn.StateKey + data string + // If write=false, true if we expect a not-exist error. + notExists bool + }{ + { + id: "foo", + notExists: true, + }, + { + write: true, + id: "foo", + data: "bar", + }, + { + id: "foo", + data: "bar", + }, + { + id: "baz", + notExists: true, + }, + { + write: true, + id: "baz", + data: "quux", + }, + { + id: "foo", + data: "bar", + }, + { + id: "baz", + data: "quux", + }, + } + + for _, test := range tests { + if test.write { + if err := store.WriteState(test.id, []byte(test.data)); err != nil { + t.Errorf("writing %q to %q: %v", test.data, test.id, err) + } + } else { + bs, err := store.ReadState(test.id) + if err != nil { + if test.notExists && err == ipn.ErrStateNotExist { + continue + } + t.Errorf("reading %q: %v", test.id, err) + continue + } + if string(bs) != test.data { + t.Errorf("reading %q: got %q, want %q", test.id, string(bs), test.data) + } + } + } +} + +func TestMemoryStore(t *testing.T) { + tstest.PanicOnLog() + + store := new(mem.Store) + testStoreSemantics(t, store) +} + +func TestFileStore(t *testing.T) { + tstest.PanicOnLog() + + dir := t.TempDir() + path := filepath.Join(dir, "test-file-store.conf") + + store, err := NewFileStore(nil, path) + if err != nil { + t.Fatalf("creating file store failed: %v", err) + } + + testStoreSemantics(t, store) + + // Build a brand new file store and check that both IDs written + // above are still there. + store, err = NewFileStore(nil, path) + if err != nil { + t.Fatalf("creating second file store failed: %v", err) + } + + expected := map[ipn.StateKey]string{ + "foo": "bar", + "baz": "quux", + } + for key, want := range expected { + bs, err := store.ReadState(key) + if err != nil { + t.Errorf("reading %q (2nd store): %v", key, err) + continue + } + if string(bs) != want { + t.Errorf("reading %q (2nd store): got %q, want %q", key, bs, want) + } + } +} diff --git a/ipn/store_test.go b/ipn/store_test.go index 330f67969..fcc082d8a 100644 --- a/ipn/store_test.go +++ b/ipn/store_test.go @@ -1,48 +1,48 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package ipn - -import ( - "bytes" - "sync" - "testing" - - "tailscale.com/util/mak" -) - -type memStore struct { - mu sync.Mutex - writes int - m map[StateKey][]byte -} - -func (s *memStore) ReadState(k StateKey) ([]byte, error) { - s.mu.Lock() - defer s.mu.Unlock() - return bytes.Clone(s.m[k]), nil -} - -func (s *memStore) WriteState(k StateKey, v []byte) error { - s.mu.Lock() - defer s.mu.Unlock() - mak.Set(&s.m, k, bytes.Clone(v)) - s.writes++ - return nil -} - -func TestWriteState(t *testing.T) { - var ss StateStore = new(memStore) - WriteState(ss, "foo", []byte("bar")) - WriteState(ss, "foo", []byte("bar")) - got, err := ss.ReadState("foo") - if err != nil { - t.Fatal(err) - } - if want := []byte("bar"); !bytes.Equal(got, want) { - t.Errorf("got %q; want %q", got, want) - } - if got, want := ss.(*memStore).writes, 1; got != want { - t.Errorf("got %d writes; want %d", got, want) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipn + +import ( + "bytes" + "sync" + "testing" + + "tailscale.com/util/mak" +) + +type memStore struct { + mu sync.Mutex + writes int + m map[StateKey][]byte +} + +func (s *memStore) ReadState(k StateKey) ([]byte, error) { + s.mu.Lock() + defer s.mu.Unlock() + return bytes.Clone(s.m[k]), nil +} + +func (s *memStore) WriteState(k StateKey, v []byte) error { + s.mu.Lock() + defer s.mu.Unlock() + mak.Set(&s.m, k, bytes.Clone(v)) + s.writes++ + return nil +} + +func TestWriteState(t *testing.T) { + var ss StateStore = new(memStore) + WriteState(ss, "foo", []byte("bar")) + WriteState(ss, "foo", []byte("bar")) + got, err := ss.ReadState("foo") + if err != nil { + t.Fatal(err) + } + if want := []byte("bar"); !bytes.Equal(got, want) { + t.Errorf("got %q; want %q", got, want) + } + if got, want := ss.(*memStore).writes, 1; got != want { + t.Errorf("got %d writes; want %d", got, want) + } +} diff --git a/jsondb/db.go b/jsondb/db.go index c45c1f819..68bb05af4 100644 --- a/jsondb/db.go +++ b/jsondb/db.go @@ -1,57 +1,57 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package jsondb provides a trivial "database": a Go object saved to -// disk as JSON. -package jsondb - -import ( - "encoding/json" - "errors" - "io/fs" - "os" - - "tailscale.com/atomicfile" -) - -// DB is a database backed by a JSON file. -type DB[T any] struct { - // Data is the contents of the database. - Data *T - - path string -} - -// Open opens the database at path, creating it with a zero value if -// necessary. -func Open[T any](path string) (*DB[T], error) { - bs, err := os.ReadFile(path) - if errors.Is(err, fs.ErrNotExist) { - return &DB[T]{ - Data: new(T), - path: path, - }, nil - } else if err != nil { - return nil, err - } - - var val T - if err := json.Unmarshal(bs, &val); err != nil { - return nil, err - } - - return &DB[T]{ - Data: &val, - path: path, - }, nil -} - -// Save writes db.Data back to disk. -func (db *DB[T]) Save() error { - bs, err := json.Marshal(db.Data) - if err != nil { - return err - } - - return atomicfile.WriteFile(db.path, bs, 0600) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package jsondb provides a trivial "database": a Go object saved to +// disk as JSON. +package jsondb + +import ( + "encoding/json" + "errors" + "io/fs" + "os" + + "tailscale.com/atomicfile" +) + +// DB is a database backed by a JSON file. +type DB[T any] struct { + // Data is the contents of the database. + Data *T + + path string +} + +// Open opens the database at path, creating it with a zero value if +// necessary. +func Open[T any](path string) (*DB[T], error) { + bs, err := os.ReadFile(path) + if errors.Is(err, fs.ErrNotExist) { + return &DB[T]{ + Data: new(T), + path: path, + }, nil + } else if err != nil { + return nil, err + } + + var val T + if err := json.Unmarshal(bs, &val); err != nil { + return nil, err + } + + return &DB[T]{ + Data: &val, + path: path, + }, nil +} + +// Save writes db.Data back to disk. +func (db *DB[T]) Save() error { + bs, err := json.Marshal(db.Data) + if err != nil { + return err + } + + return atomicfile.WriteFile(db.path, bs, 0600) +} diff --git a/jsondb/db_test.go b/jsondb/db_test.go index a78b15b4f..655754f38 100644 --- a/jsondb/db_test.go +++ b/jsondb/db_test.go @@ -1,55 +1,55 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package jsondb - -import ( - "log" - "os" - "path/filepath" - "testing" - - "github.com/google/go-cmp/cmp" -) - -func TestDB(t *testing.T) { - dir, err := os.MkdirTemp("", "db-test") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(dir) - - path := filepath.Join(dir, "db.json") - db, err := Open[testDB](path) - if err != nil { - t.Fatalf("creating empty DB: %v", err) - } - - if diff := cmp.Diff(db.Data, &testDB{}, cmp.AllowUnexported(testDB{})); diff != "" { - t.Fatalf("unexpected empty DB content (-got+want):\n%s", diff) - } - db.Data.MyString = "test" - db.Data.unexported = "don't keep" - db.Data.AnInt = 42 - if err := db.Save(); err != nil { - t.Fatalf("saving database: %v", err) - } - - db2, err := Open[testDB](path) - if err != nil { - log.Fatalf("opening DB again: %v", err) - } - want := &testDB{ - MyString: "test", - AnInt: 42, - } - if diff := cmp.Diff(db2.Data, want, cmp.AllowUnexported(testDB{})); diff != "" { - t.Fatalf("unexpected saved DB content (-got+want):\n%s", diff) - } -} - -type testDB struct { - MyString string - unexported string - AnInt int64 -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package jsondb + +import ( + "log" + "os" + "path/filepath" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestDB(t *testing.T) { + dir, err := os.MkdirTemp("", "db-test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + path := filepath.Join(dir, "db.json") + db, err := Open[testDB](path) + if err != nil { + t.Fatalf("creating empty DB: %v", err) + } + + if diff := cmp.Diff(db.Data, &testDB{}, cmp.AllowUnexported(testDB{})); diff != "" { + t.Fatalf("unexpected empty DB content (-got+want):\n%s", diff) + } + db.Data.MyString = "test" + db.Data.unexported = "don't keep" + db.Data.AnInt = 42 + if err := db.Save(); err != nil { + t.Fatalf("saving database: %v", err) + } + + db2, err := Open[testDB](path) + if err != nil { + log.Fatalf("opening DB again: %v", err) + } + want := &testDB{ + MyString: "test", + AnInt: 42, + } + if diff := cmp.Diff(db2.Data, want, cmp.AllowUnexported(testDB{})); diff != "" { + t.Fatalf("unexpected saved DB content (-got+want):\n%s", diff) + } +} + +type testDB struct { + MyString string + unexported string + AnInt int64 +} diff --git a/licenses/licenses.go b/licenses/licenses.go index 3ec701321..5e59edb9f 100644 --- a/licenses/licenses.go +++ b/licenses/licenses.go @@ -1,21 +1,21 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package licenses provides utilities for working with open source licenses. -package licenses - -import "runtime" - -// LicensesURL returns the absolute URL containing open source license information for the current platform. -func LicensesURL() string { - switch runtime.GOOS { - case "android": - return "https://tailscale.com/licenses/android" - case "darwin", "ios": - return "https://tailscale.com/licenses/apple" - case "windows": - return "https://tailscale.com/licenses/windows" - default: - return "https://tailscale.com/licenses/tailscale" - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package licenses provides utilities for working with open source licenses. +package licenses + +import "runtime" + +// LicensesURL returns the absolute URL containing open source license information for the current platform. +func LicensesURL() string { + switch runtime.GOOS { + case "android": + return "https://tailscale.com/licenses/android" + case "darwin", "ios": + return "https://tailscale.com/licenses/apple" + case "windows": + return "https://tailscale.com/licenses/windows" + default: + return "https://tailscale.com/licenses/tailscale" + } +} diff --git a/log/filelogger/log.go b/log/filelogger/log.go index 9d7097eb8..599e5237b 100644 --- a/log/filelogger/log.go +++ b/log/filelogger/log.go @@ -1,228 +1,228 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package filelogger provides localdisk log writing & rotation, primarily for Windows -// clients. (We get this for free on other platforms.) -package filelogger - -import ( - "bytes" - "fmt" - "log" - "os" - "path/filepath" - "runtime" - "strings" - "sync" - "time" - - "tailscale.com/types/logger" -) - -const ( - maxSize = 100 << 20 - maxFiles = 50 -) - -// New returns a logf wrapper that appends to local disk log -// files on Windows, rotating old log files as needed to stay under -// file count & byte limits. -func New(fileBasePrefix, logID string, logf logger.Logf) logger.Logf { - if runtime.GOOS != "windows" { - panic("not yet supported on any platform except Windows") - } - if logf == nil { - panic("nil logf") - } - dir := filepath.Join(os.Getenv("ProgramData"), "Tailscale", "Logs") - - if err := os.MkdirAll(dir, 0700); err != nil { - log.Printf("failed to create local log directory; not writing logs to disk: %v", err) - return logf - } - logf("local disk logdir: %v", dir) - lfw := &logFileWriter{ - fileBasePrefix: fileBasePrefix, - logID: logID, - dir: dir, - wrappedLogf: logf, - } - return lfw.Logf -} - -// logFileWriter is the state for the log writer & rotator. -type logFileWriter struct { - dir string // e.g. `C:\Users\FooBarUser\AppData\Local\Tailscale\Logs` - logID string // hex logID - fileBasePrefix string // e.g. "tailscale-service" or "tailscale-gui" - wrappedLogf logger.Logf // underlying logger to send to - - mu sync.Mutex // guards following - buf bytes.Buffer // scratch buffer to avoid allocs - fday civilDay // day that f was opened; zero means no file yet open - f *os.File // file currently opened for append -} - -// civilDay is a year, month, and day in the local timezone. -// It's a comparable value type. -type civilDay struct { - year int - month time.Month - day int -} - -func dayOf(t time.Time) civilDay { - return civilDay{t.Year(), t.Month(), t.Day()} -} - -func (w *logFileWriter) Logf(format string, a ...any) { - w.mu.Lock() - defer w.mu.Unlock() - - w.buf.Reset() - fmt.Fprintf(&w.buf, format, a...) - if w.buf.Len() == 0 { - return - } - out := w.buf.Bytes() - w.wrappedLogf("%s", out) - - // Make sure there's a final newline before we write to the log file. - if out[len(out)-1] != '\n' { - w.buf.WriteByte('\n') - out = w.buf.Bytes() - } - - w.appendToFileLocked(out) -} - -// out should end in a newline. -// w.mu must be held. -func (w *logFileWriter) appendToFileLocked(out []byte) { - now := time.Now() - day := dayOf(now) - if w.fday != day { - w.startNewFileLocked() - } - out = removeDatePrefix(out) - if w.f != nil { - // RFC3339Nano but with a fixed number (3) of nanosecond digits: - const formatPre = "2006-01-02T15:04:05" - const formatPost = "Z07:00" - fmt.Fprintf(w.f, "%s.%03d%s: %s", - now.Format(formatPre), - now.Nanosecond()/int(time.Millisecond/time.Nanosecond), - now.Format(formatPost), - out) - } -} - -func isNum(b byte) bool { return '0' <= b && b <= '9' } - -// removeDatePrefix returns a subslice of v with the log package's -// standard datetime prefix format removed, if present. -func removeDatePrefix(v []byte) []byte { - const format = "2009/01/23 01:23:23 " - if len(v) < len(format) { - return v - } - for i, b := range v[:len(format)] { - fb := format[i] - if isNum(fb) { - if !isNum(b) { - return v - } - continue - } - if b != fb { - return v - } - } - return v[len(format):] -} - -// startNewFileLocked opens a new log file for writing -// and also cleans up any old files. -// -// w.mu must be held. -func (w *logFileWriter) startNewFileLocked() { - var oldName string - if w.f != nil { - oldName = filepath.Base(w.f.Name()) - w.f.Close() - w.f = nil - w.fday = civilDay{} - } - w.cleanLocked() - - now := time.Now() - day := dayOf(now) - name := filepath.Join(w.dir, fmt.Sprintf("%s-%04d%02d%02dT%02d%02d%02d-%d.txt", - w.fileBasePrefix, - day.year, - day.month, - day.day, - now.Hour(), - now.Minute(), - now.Second(), - now.Unix())) - var err error - w.f, err = os.Create(name) - if err != nil { - w.wrappedLogf("failed to create log file: %v", err) - return - } - if oldName != "" { - fmt.Fprintf(w.f, "(logID %q; continued from log file %s)\n", w.logID, oldName) - } else { - fmt.Fprintf(w.f, "(logID %q)\n", w.logID) - } - w.fday = day -} - -// cleanLocked cleans up old log files. -// -// w.mu must be held. -func (w *logFileWriter) cleanLocked() { - entries, _ := os.ReadDir(w.dir) - prefix := w.fileBasePrefix + "-" - fileSize := map[string]int64{} - var files []string - var sumSize int64 - for _, entry := range entries { - fi, err := entry.Info() - if err != nil { - w.wrappedLogf("error getting log file info: %v", err) - continue - } - - baseName := filepath.Base(fi.Name()) - if !strings.HasPrefix(baseName, prefix) { - continue - } - size := fi.Size() - fileSize[baseName] = size - sumSize += size - files = append(files, baseName) - } - if sumSize > maxSize { - w.wrappedLogf("cleaning log files; sum byte count %d > %d", sumSize, maxSize) - } - if len(files) > maxFiles { - w.wrappedLogf("cleaning log files; number of files %d > %d", len(files), maxFiles) - } - for (sumSize > maxSize || len(files) > maxFiles) && len(files) > 0 { - target := files[0] - files = files[1:] - - targetSize := fileSize[target] - targetFull := filepath.Join(w.dir, target) - err := os.Remove(targetFull) - if err != nil { - w.wrappedLogf("error cleaning log file: %v", err) - } else { - sumSize -= targetSize - w.wrappedLogf("cleaned log file %s (size %d); new bytes=%v, files=%v", targetFull, targetSize, sumSize, len(files)) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package filelogger provides localdisk log writing & rotation, primarily for Windows +// clients. (We get this for free on other platforms.) +package filelogger + +import ( + "bytes" + "fmt" + "log" + "os" + "path/filepath" + "runtime" + "strings" + "sync" + "time" + + "tailscale.com/types/logger" +) + +const ( + maxSize = 100 << 20 + maxFiles = 50 +) + +// New returns a logf wrapper that appends to local disk log +// files on Windows, rotating old log files as needed to stay under +// file count & byte limits. +func New(fileBasePrefix, logID string, logf logger.Logf) logger.Logf { + if runtime.GOOS != "windows" { + panic("not yet supported on any platform except Windows") + } + if logf == nil { + panic("nil logf") + } + dir := filepath.Join(os.Getenv("ProgramData"), "Tailscale", "Logs") + + if err := os.MkdirAll(dir, 0700); err != nil { + log.Printf("failed to create local log directory; not writing logs to disk: %v", err) + return logf + } + logf("local disk logdir: %v", dir) + lfw := &logFileWriter{ + fileBasePrefix: fileBasePrefix, + logID: logID, + dir: dir, + wrappedLogf: logf, + } + return lfw.Logf +} + +// logFileWriter is the state for the log writer & rotator. +type logFileWriter struct { + dir string // e.g. `C:\Users\FooBarUser\AppData\Local\Tailscale\Logs` + logID string // hex logID + fileBasePrefix string // e.g. "tailscale-service" or "tailscale-gui" + wrappedLogf logger.Logf // underlying logger to send to + + mu sync.Mutex // guards following + buf bytes.Buffer // scratch buffer to avoid allocs + fday civilDay // day that f was opened; zero means no file yet open + f *os.File // file currently opened for append +} + +// civilDay is a year, month, and day in the local timezone. +// It's a comparable value type. +type civilDay struct { + year int + month time.Month + day int +} + +func dayOf(t time.Time) civilDay { + return civilDay{t.Year(), t.Month(), t.Day()} +} + +func (w *logFileWriter) Logf(format string, a ...any) { + w.mu.Lock() + defer w.mu.Unlock() + + w.buf.Reset() + fmt.Fprintf(&w.buf, format, a...) + if w.buf.Len() == 0 { + return + } + out := w.buf.Bytes() + w.wrappedLogf("%s", out) + + // Make sure there's a final newline before we write to the log file. + if out[len(out)-1] != '\n' { + w.buf.WriteByte('\n') + out = w.buf.Bytes() + } + + w.appendToFileLocked(out) +} + +// out should end in a newline. +// w.mu must be held. +func (w *logFileWriter) appendToFileLocked(out []byte) { + now := time.Now() + day := dayOf(now) + if w.fday != day { + w.startNewFileLocked() + } + out = removeDatePrefix(out) + if w.f != nil { + // RFC3339Nano but with a fixed number (3) of nanosecond digits: + const formatPre = "2006-01-02T15:04:05" + const formatPost = "Z07:00" + fmt.Fprintf(w.f, "%s.%03d%s: %s", + now.Format(formatPre), + now.Nanosecond()/int(time.Millisecond/time.Nanosecond), + now.Format(formatPost), + out) + } +} + +func isNum(b byte) bool { return '0' <= b && b <= '9' } + +// removeDatePrefix returns a subslice of v with the log package's +// standard datetime prefix format removed, if present. +func removeDatePrefix(v []byte) []byte { + const format = "2009/01/23 01:23:23 " + if len(v) < len(format) { + return v + } + for i, b := range v[:len(format)] { + fb := format[i] + if isNum(fb) { + if !isNum(b) { + return v + } + continue + } + if b != fb { + return v + } + } + return v[len(format):] +} + +// startNewFileLocked opens a new log file for writing +// and also cleans up any old files. +// +// w.mu must be held. +func (w *logFileWriter) startNewFileLocked() { + var oldName string + if w.f != nil { + oldName = filepath.Base(w.f.Name()) + w.f.Close() + w.f = nil + w.fday = civilDay{} + } + w.cleanLocked() + + now := time.Now() + day := dayOf(now) + name := filepath.Join(w.dir, fmt.Sprintf("%s-%04d%02d%02dT%02d%02d%02d-%d.txt", + w.fileBasePrefix, + day.year, + day.month, + day.day, + now.Hour(), + now.Minute(), + now.Second(), + now.Unix())) + var err error + w.f, err = os.Create(name) + if err != nil { + w.wrappedLogf("failed to create log file: %v", err) + return + } + if oldName != "" { + fmt.Fprintf(w.f, "(logID %q; continued from log file %s)\n", w.logID, oldName) + } else { + fmt.Fprintf(w.f, "(logID %q)\n", w.logID) + } + w.fday = day +} + +// cleanLocked cleans up old log files. +// +// w.mu must be held. +func (w *logFileWriter) cleanLocked() { + entries, _ := os.ReadDir(w.dir) + prefix := w.fileBasePrefix + "-" + fileSize := map[string]int64{} + var files []string + var sumSize int64 + for _, entry := range entries { + fi, err := entry.Info() + if err != nil { + w.wrappedLogf("error getting log file info: %v", err) + continue + } + + baseName := filepath.Base(fi.Name()) + if !strings.HasPrefix(baseName, prefix) { + continue + } + size := fi.Size() + fileSize[baseName] = size + sumSize += size + files = append(files, baseName) + } + if sumSize > maxSize { + w.wrappedLogf("cleaning log files; sum byte count %d > %d", sumSize, maxSize) + } + if len(files) > maxFiles { + w.wrappedLogf("cleaning log files; number of files %d > %d", len(files), maxFiles) + } + for (sumSize > maxSize || len(files) > maxFiles) && len(files) > 0 { + target := files[0] + files = files[1:] + + targetSize := fileSize[target] + targetFull := filepath.Join(w.dir, target) + err := os.Remove(targetFull) + if err != nil { + w.wrappedLogf("error cleaning log file: %v", err) + } else { + sumSize -= targetSize + w.wrappedLogf("cleaned log file %s (size %d); new bytes=%v, files=%v", targetFull, targetSize, sumSize, len(files)) + } + } +} diff --git a/log/filelogger/log_test.go b/log/filelogger/log_test.go index 27f80ab0a..dfa489637 100644 --- a/log/filelogger/log_test.go +++ b/log/filelogger/log_test.go @@ -1,27 +1,27 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package filelogger - -import "testing" - -func TestRemoveDatePrefix(t *testing.T) { - tests := []struct { - in, want string - }{ - {"", ""}, - {"\n", "\n"}, - {"2009/01/23 01:23:23", "2009/01/23 01:23:23"}, - {"2009/01/23 01:23:23 \n", "\n"}, - {"2009/01/23 01:23:23 foo\n", "foo\n"}, - {"9999/01/23 01:23:23 foo\n", "foo\n"}, - {"2009_01/23 01:23:23 had an underscore\n", "2009_01/23 01:23:23 had an underscore\n"}, - } - for i, tt := range tests { - got := removeDatePrefix([]byte(tt.in)) - if string(got) != tt.want { - t.Logf("[%d] removeDatePrefix(%q) = %q; want %q", i, tt.in, got, tt.want) - } - } - -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package filelogger + +import "testing" + +func TestRemoveDatePrefix(t *testing.T) { + tests := []struct { + in, want string + }{ + {"", ""}, + {"\n", "\n"}, + {"2009/01/23 01:23:23", "2009/01/23 01:23:23"}, + {"2009/01/23 01:23:23 \n", "\n"}, + {"2009/01/23 01:23:23 foo\n", "foo\n"}, + {"9999/01/23 01:23:23 foo\n", "foo\n"}, + {"2009_01/23 01:23:23 had an underscore\n", "2009_01/23 01:23:23 had an underscore\n"}, + } + for i, tt := range tests { + got := removeDatePrefix([]byte(tt.in)) + if string(got) != tt.want { + t.Logf("[%d] removeDatePrefix(%q) = %q; want %q", i, tt.in, got, tt.want) + } + } + +} diff --git a/logpolicy/logpolicy_test.go b/logpolicy/logpolicy_test.go index c0cdfb965..fdbfe4506 100644 --- a/logpolicy/logpolicy_test.go +++ b/logpolicy/logpolicy_test.go @@ -1,36 +1,36 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package logpolicy - -import ( - "os" - "reflect" - "testing" -) - -func TestLogHost(t *testing.T) { - v := reflect.ValueOf(&getLogTargetOnce).Elem() - reset := func() { - v.Set(reflect.Zero(v.Type())) - } - defer reset() - - tests := []struct { - env string - want string - }{ - {"", "log.tailscale.io"}, - {"http://foo.com", "foo.com"}, - {"https://foo.com", "foo.com"}, - {"https://foo.com/", "foo.com"}, - {"https://foo.com:123/", "foo.com"}, - } - for _, tt := range tests { - reset() - os.Setenv("TS_LOG_TARGET", tt.env) - if got := LogHost(); got != tt.want { - t.Errorf("for env %q, got %q, want %q", tt.env, got, tt.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package logpolicy + +import ( + "os" + "reflect" + "testing" +) + +func TestLogHost(t *testing.T) { + v := reflect.ValueOf(&getLogTargetOnce).Elem() + reset := func() { + v.Set(reflect.Zero(v.Type())) + } + defer reset() + + tests := []struct { + env string + want string + }{ + {"", "log.tailscale.io"}, + {"http://foo.com", "foo.com"}, + {"https://foo.com", "foo.com"}, + {"https://foo.com/", "foo.com"}, + {"https://foo.com:123/", "foo.com"}, + } + for _, tt := range tests { + reset() + os.Setenv("TS_LOG_TARGET", tt.env) + if got := LogHost(); got != tt.want { + t.Errorf("for env %q, got %q, want %q", tt.env, got, tt.want) + } + } +} diff --git a/logtail/.gitignore b/logtail/.gitignore index b262949a8..0b29b4aca 100644 --- a/logtail/.gitignore +++ b/logtail/.gitignore @@ -1,6 +1,6 @@ -*~ -*.out -/example/logadopt/logadopt -/example/logreprocess/logreprocess -/example/logtail/logtail -/logtail +*~ +*.out +/example/logadopt/logadopt +/example/logreprocess/logreprocess +/example/logtail/logtail +/logtail diff --git a/logtail/README.md b/logtail/README.md index b7b2ada34..20d22c350 100644 --- a/logtail/README.md +++ b/logtail/README.md @@ -1,10 +1,10 @@ -# Tailscale Logs Service - -This github repository contains libraries, documentation, and examples -for working with the public API of the tailscale logs service. - -For a very quick introduction to the core features, read the -[API docs](api.md) and peruse the -[logs reprocessing](./example/logreprocess/demo.sh) example. - +# Tailscale Logs Service + +This github repository contains libraries, documentation, and examples +for working with the public API of the tailscale logs service. + +For a very quick introduction to the core features, read the +[API docs](api.md) and peruse the +[logs reprocessing](./example/logreprocess/demo.sh) example. + For more information, write to info@tailscale.io. \ No newline at end of file diff --git a/logtail/api.md b/logtail/api.md index 296913ce4..8ec0b69c0 100644 --- a/logtail/api.md +++ b/logtail/api.md @@ -1,195 +1,195 @@ -# Tailscale Logs Service - -The Tailscale Logs Service defines a REST interface for configuring, storing, -retrieving, and processing log entries. - -# Overview - -HTTP requests are received at the service **base URL** -[https://log.tailscale.io](https://log.tailscale.io), and return JSON-encoded -responses using standard HTTP response codes. - -Authorization for the configuration and retrieval APIs is done with a secret -API key passed as the HTTP basic auth username. Secret keys are generated via -the web UI at base URL. An example of using basic auth with curl: - - curl -u : https://log.tailscale.io/collections - -In the future, an HTTP header will allow using MessagePack instead of JSON. - -## Collections - -Logs are organized into collections. Inside each collection is any number of -instances. - -A collection is a domain name. It is a grouping of related logs. As a -guideline, create one collection per product using subdomains of your -company's domain name. Collections must be registered with the logs service -before any attempt is made to store logs. - -## Instances - -Each collection is a set of instances. There is one instance per machine -writing logs. - -An instance has a name and a number. An instance has a **private** and -**public** ID. The private ID is a 32-byte random number encoded as hex. -The public ID is the SHA-256 hash of the private ID, encoded as hex. - -The private ID is used to write logs. The only copy of the private ID -should be on the machine sending logs. Ideally it is generated on the -machine. Logs can be written as soon as a private ID is generated. - -The public ID is used to read and adopt logs. It is designed to be sent -to a service that also holds a logs service API key. - -The tailscale logs service will store any logs for a short period of time. -To enable logs retention, the log can be **adopted** using the public ID -and a logs service API key. -Once this is done, logs will be retained long-term (for the configured -retention period). - -Unadopted instance logs are stored temporarily to help with debugging: -a misconfigured machine writing logs with a bad ID can be spotted by -reading the logs. -If a public ID is not adopted, storage is tightly capped and logs are -deleted after 12 hours. - -# APIs - -## Storage - -### `POST /c//` — send a log - -The body of the request is JSON. - -A **single message** is an object with properties: - -`{ }` - -The client may send any properties it wants in the JSON message, except -for the `logtail` property which has special meaning. Inside the logtail -object the client may only set the following properties: - -- `client_time` in the format of RFC3339: "2006-01-02T15:04:05.999999999Z07:00" - -A future version of the logs service API will also support: - -- `client_time_offset` a integer of nanoseconds since the client was reset -- `client_time_reset` a boolean if set to true resets the time offset counter - -On receipt by the server the `client_time_offset` is transformed into a -`client_time` based on the `server_time` when the first (or -client_time_reset) event was received. - -If any other properties are set in the logtail object they are moved into -the "error" field, the message is saved and a 4xx status code is returned. - -A **batch of messages** is a JSON array filled with single message objects: - -`[ { }, { }, ... ]` - -If any of the array entries are not objects, the content is converted -into a message with a `"logtail": { "error": ...}` property, saved, and -a 4xx status code is returned. - -Similarly any other request content not matching one of these formats is -saved in a logtail error field, and a 4xx status code is returned. - -An invalid collection name returns `{"error": "invalid collection name"}` -along with a 403 status code. - -Clients are encouraged to: - -- POST as rapidly as possible (if not battery constrained). This minimizes - both the time necessary to see logs in a log viewer and the chance of - losing logs. -- Use HTTP/2 when streaming logs, as it does a much better job of - maintaining a TLS connection to minimize overhead for subsequent posts. - -A future version of logs service API will support sending requests with -`Content-Encoding: zstd`. - -## Retrieval - -### `GET /collections` — query the set of collections and instances - -Returns a JSON object listing all of the named collections. - -The caller can query-encode the following fields: - -- `collection-name` — limit the results to one collection - - ``` - { - "collections": { - "collection1.yourcompany.com": { - "instances": { - "" :{ - "first-seen": "timestamp", - "size": 4096 - }, - "" :{ - "first-seen": "timestamp", - "size": 512000, - "orphan": true, - } - } - } - } - } - ``` - -### `GET /c/` — query stored logs - -The caller can query-encode the following fields: - -- `instances` — zero or more log collection instances to limit results to -- `time-start` — the earliest log to include -- One of: - - `time-end` — the latest log to include - - `max-count` — maximum number of logs to return, allows paging - - `stream` — boolean that keeps the response dangling, streaming in - logs like `tail -f`. Incompatible with logtail-time-end. - -In **stream=false** mode, the response is a single JSON object: - - { - // TODO: header fields - "logs": [ {}, {}, ... ] - } - -In **stream=true** mode, the response begins with a JSON header object -similar to the storage format, and then is a sequence of JSON log -objects, `{...}`, one per line. The server continues to send these until -the client closes the connection. - -## Configuration - -For organizations with a small number of instances writing logs, the -Configuration API are best used by a trusted human operator, usually -through a GUI. Organizations with many instances will need to automate -the creation of tokens. - -### `POST /collections` — create or delete a collection - -The caller must set the `collection` property and `action=create` or -`action=delete`, either form encoded or JSON encoded. Its character set -is restricted to the mundane: [a-zA-Z0-9-_.]+ - -Collection names are a global space. Typically they are a domain name. - -### `POST /instances` — adopt an instance into a collection - -The caller must send the following properties, form encoded or JSON encoded: - -- `collection` — a valid FQDN ([a-zA-Z0-9-_.]+) -- `instances` an instance public ID encoded as hex - -The collection name must be claimed by a group the caller belongs to. -The pair (collection-name, instance-public-ID) may or may not already have -logs associated with it. - -On failure, an error message is returned with a 4xx or 5xx status code: - +# Tailscale Logs Service + +The Tailscale Logs Service defines a REST interface for configuring, storing, +retrieving, and processing log entries. + +# Overview + +HTTP requests are received at the service **base URL** +[https://log.tailscale.io](https://log.tailscale.io), and return JSON-encoded +responses using standard HTTP response codes. + +Authorization for the configuration and retrieval APIs is done with a secret +API key passed as the HTTP basic auth username. Secret keys are generated via +the web UI at base URL. An example of using basic auth with curl: + + curl -u : https://log.tailscale.io/collections + +In the future, an HTTP header will allow using MessagePack instead of JSON. + +## Collections + +Logs are organized into collections. Inside each collection is any number of +instances. + +A collection is a domain name. It is a grouping of related logs. As a +guideline, create one collection per product using subdomains of your +company's domain name. Collections must be registered with the logs service +before any attempt is made to store logs. + +## Instances + +Each collection is a set of instances. There is one instance per machine +writing logs. + +An instance has a name and a number. An instance has a **private** and +**public** ID. The private ID is a 32-byte random number encoded as hex. +The public ID is the SHA-256 hash of the private ID, encoded as hex. + +The private ID is used to write logs. The only copy of the private ID +should be on the machine sending logs. Ideally it is generated on the +machine. Logs can be written as soon as a private ID is generated. + +The public ID is used to read and adopt logs. It is designed to be sent +to a service that also holds a logs service API key. + +The tailscale logs service will store any logs for a short period of time. +To enable logs retention, the log can be **adopted** using the public ID +and a logs service API key. +Once this is done, logs will be retained long-term (for the configured +retention period). + +Unadopted instance logs are stored temporarily to help with debugging: +a misconfigured machine writing logs with a bad ID can be spotted by +reading the logs. +If a public ID is not adopted, storage is tightly capped and logs are +deleted after 12 hours. + +# APIs + +## Storage + +### `POST /c//` — send a log + +The body of the request is JSON. + +A **single message** is an object with properties: + +`{ }` + +The client may send any properties it wants in the JSON message, except +for the `logtail` property which has special meaning. Inside the logtail +object the client may only set the following properties: + +- `client_time` in the format of RFC3339: "2006-01-02T15:04:05.999999999Z07:00" + +A future version of the logs service API will also support: + +- `client_time_offset` a integer of nanoseconds since the client was reset +- `client_time_reset` a boolean if set to true resets the time offset counter + +On receipt by the server the `client_time_offset` is transformed into a +`client_time` based on the `server_time` when the first (or +client_time_reset) event was received. + +If any other properties are set in the logtail object they are moved into +the "error" field, the message is saved and a 4xx status code is returned. + +A **batch of messages** is a JSON array filled with single message objects: + +`[ { }, { }, ... ]` + +If any of the array entries are not objects, the content is converted +into a message with a `"logtail": { "error": ...}` property, saved, and +a 4xx status code is returned. + +Similarly any other request content not matching one of these formats is +saved in a logtail error field, and a 4xx status code is returned. + +An invalid collection name returns `{"error": "invalid collection name"}` +along with a 403 status code. + +Clients are encouraged to: + +- POST as rapidly as possible (if not battery constrained). This minimizes + both the time necessary to see logs in a log viewer and the chance of + losing logs. +- Use HTTP/2 when streaming logs, as it does a much better job of + maintaining a TLS connection to minimize overhead for subsequent posts. + +A future version of logs service API will support sending requests with +`Content-Encoding: zstd`. + +## Retrieval + +### `GET /collections` — query the set of collections and instances + +Returns a JSON object listing all of the named collections. + +The caller can query-encode the following fields: + +- `collection-name` — limit the results to one collection + + ``` + { + "collections": { + "collection1.yourcompany.com": { + "instances": { + "" :{ + "first-seen": "timestamp", + "size": 4096 + }, + "" :{ + "first-seen": "timestamp", + "size": 512000, + "orphan": true, + } + } + } + } + } + ``` + +### `GET /c/` — query stored logs + +The caller can query-encode the following fields: + +- `instances` — zero or more log collection instances to limit results to +- `time-start` — the earliest log to include +- One of: + - `time-end` — the latest log to include + - `max-count` — maximum number of logs to return, allows paging + - `stream` — boolean that keeps the response dangling, streaming in + logs like `tail -f`. Incompatible with logtail-time-end. + +In **stream=false** mode, the response is a single JSON object: + + { + // TODO: header fields + "logs": [ {}, {}, ... ] + } + +In **stream=true** mode, the response begins with a JSON header object +similar to the storage format, and then is a sequence of JSON log +objects, `{...}`, one per line. The server continues to send these until +the client closes the connection. + +## Configuration + +For organizations with a small number of instances writing logs, the +Configuration API are best used by a trusted human operator, usually +through a GUI. Organizations with many instances will need to automate +the creation of tokens. + +### `POST /collections` — create or delete a collection + +The caller must set the `collection` property and `action=create` or +`action=delete`, either form encoded or JSON encoded. Its character set +is restricted to the mundane: [a-zA-Z0-9-_.]+ + +Collection names are a global space. Typically they are a domain name. + +### `POST /instances` — adopt an instance into a collection + +The caller must send the following properties, form encoded or JSON encoded: + +- `collection` — a valid FQDN ([a-zA-Z0-9-_.]+) +- `instances` an instance public ID encoded as hex + +The collection name must be claimed by a group the caller belongs to. +The pair (collection-name, instance-public-ID) may or may not already have +logs associated with it. + +On failure, an error message is returned with a 4xx or 5xx status code: + `{"error": "what went wrong"}` \ No newline at end of file diff --git a/logtail/example/logreprocess/demo.sh b/logtail/example/logreprocess/demo.sh index eaec706a3..4ec819a67 100755 --- a/logtail/example/logreprocess/demo.sh +++ b/logtail/example/logreprocess/demo.sh @@ -1,86 +1,86 @@ -#!/bin/bash -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause - -# -# This shell script demonstrates writing logs from machines -# and then reprocessing those logs to amalgamate python tracebacks -# into a single log entry in a new collection. -# -# To run this demo, first install the example applications: -# -# go install tailscale.com/logtail/example/... -# -# Then generate a LOGTAIL_API_KEY and two test collections by visiting: -# -# https://log.tailscale.io -# -# Then set the three variables below. -trap 'rv=$?; [ "$rv" = 0 ] || echo "-- exiting with code $rv"; exit $rv' EXIT -set -e - -LOG_TEXT='server starting -config file loaded -answering queries -Traceback (most recent call last): - File "/Users/crawshaw/junk.py", line 6, in - main() - File "/Users/crawshaw/junk.py", line 4, in main - raise Exception("oops") -Exception: oops' - -die() { - echo "$0: $*" >&2 - exit 1 -} - -msg() { - echo "-- $*" >&2 -} - -if [ -z "$LOGTAIL_API_KEY" ]; then - die "LOGTAIL_API_KEY is not set" -fi - -if [ -z "$COLLECTION_IN" ]; then - die "COLLECTION_IN is not set" -fi - -if [ -z "$COLLECTION_OUT" ]; then - die "COLLECTION_OUT is not set" -fi - -# Private IDs are 32-bytes of random hex. -# Normally you'd keep the same private IDs from one run to the next, but -# this is just an example. -msg "Generating keys..." -privateid1=$(hexdump -n 32 -e '8/4 "%08X"' /dev/urandom) -privateid2=$(hexdump -n 32 -e '8/4 "%08X"' /dev/urandom) -privateid3=$(hexdump -n 32 -e '8/4 "%08X"' /dev/urandom) - -# Public IDs are the SHA-256 of the private ID. -publicid1=$(echo -n $privateid1 | xxd -r -p - | shasum -a 256 | sed 's/ -//') -publicid2=$(echo -n $privateid2 | xxd -r -p - | shasum -a 256 | sed 's/ -//') -publicid3=$(echo -n $privateid3 | xxd -r -p - | shasum -a 256 | sed 's/ -//') - -# Write the machine logs to the input collection. -# Notice that this doesn't require an API key. -msg "Producing new logs..." -echo "$LOG_TEXT" | logtail -c $COLLECTION_IN -k $privateid1 >/dev/null -echo "$LOG_TEXT" | logtail -c $COLLECTION_IN -k $privateid2 >/dev/null - -# Adopt the logs, so they will be kept and are readable. -msg "Adopting logs..." -logadopt -p "$LOGTAIL_API_KEY" -c "$COLLECTION_IN" -m $publicid1 -logadopt -p "$LOGTAIL_API_KEY" -c "$COLLECTION_IN" -m $publicid2 - -# Reprocess the logs, amalgamating python tracebacks. -# -# We'll take that reprocessed output and write it to a separate collection, -# again via logtail. -# -# Time out quickly because all our "interesting" logs (generated -# above) have already been processed. -msg "Reprocessing logs..." -logreprocess -t 3s -c "$COLLECTION_IN" -p "$LOGTAIL_API_KEY" 2>&1 | - logtail -c "$COLLECTION_OUT" -k $privateid3 +#!/bin/bash +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause + +# +# This shell script demonstrates writing logs from machines +# and then reprocessing those logs to amalgamate python tracebacks +# into a single log entry in a new collection. +# +# To run this demo, first install the example applications: +# +# go install tailscale.com/logtail/example/... +# +# Then generate a LOGTAIL_API_KEY and two test collections by visiting: +# +# https://log.tailscale.io +# +# Then set the three variables below. +trap 'rv=$?; [ "$rv" = 0 ] || echo "-- exiting with code $rv"; exit $rv' EXIT +set -e + +LOG_TEXT='server starting +config file loaded +answering queries +Traceback (most recent call last): + File "/Users/crawshaw/junk.py", line 6, in + main() + File "/Users/crawshaw/junk.py", line 4, in main + raise Exception("oops") +Exception: oops' + +die() { + echo "$0: $*" >&2 + exit 1 +} + +msg() { + echo "-- $*" >&2 +} + +if [ -z "$LOGTAIL_API_KEY" ]; then + die "LOGTAIL_API_KEY is not set" +fi + +if [ -z "$COLLECTION_IN" ]; then + die "COLLECTION_IN is not set" +fi + +if [ -z "$COLLECTION_OUT" ]; then + die "COLLECTION_OUT is not set" +fi + +# Private IDs are 32-bytes of random hex. +# Normally you'd keep the same private IDs from one run to the next, but +# this is just an example. +msg "Generating keys..." +privateid1=$(hexdump -n 32 -e '8/4 "%08X"' /dev/urandom) +privateid2=$(hexdump -n 32 -e '8/4 "%08X"' /dev/urandom) +privateid3=$(hexdump -n 32 -e '8/4 "%08X"' /dev/urandom) + +# Public IDs are the SHA-256 of the private ID. +publicid1=$(echo -n $privateid1 | xxd -r -p - | shasum -a 256 | sed 's/ -//') +publicid2=$(echo -n $privateid2 | xxd -r -p - | shasum -a 256 | sed 's/ -//') +publicid3=$(echo -n $privateid3 | xxd -r -p - | shasum -a 256 | sed 's/ -//') + +# Write the machine logs to the input collection. +# Notice that this doesn't require an API key. +msg "Producing new logs..." +echo "$LOG_TEXT" | logtail -c $COLLECTION_IN -k $privateid1 >/dev/null +echo "$LOG_TEXT" | logtail -c $COLLECTION_IN -k $privateid2 >/dev/null + +# Adopt the logs, so they will be kept and are readable. +msg "Adopting logs..." +logadopt -p "$LOGTAIL_API_KEY" -c "$COLLECTION_IN" -m $publicid1 +logadopt -p "$LOGTAIL_API_KEY" -c "$COLLECTION_IN" -m $publicid2 + +# Reprocess the logs, amalgamating python tracebacks. +# +# We'll take that reprocessed output and write it to a separate collection, +# again via logtail. +# +# Time out quickly because all our "interesting" logs (generated +# above) have already been processed. +msg "Reprocessing logs..." +logreprocess -t 3s -c "$COLLECTION_IN" -p "$LOGTAIL_API_KEY" 2>&1 | + logtail -c "$COLLECTION_OUT" -k $privateid3 diff --git a/logtail/example/logreprocess/logreprocess.go b/logtail/example/logreprocess/logreprocess.go index e88d5b485..5dbf76578 100644 --- a/logtail/example/logreprocess/logreprocess.go +++ b/logtail/example/logreprocess/logreprocess.go @@ -1,115 +1,115 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The logreprocess program tails a log and reprocesses it. -package main - -import ( - "bufio" - "encoding/json" - "flag" - "io" - "log" - "net/http" - "os" - "strings" - "time" - - "tailscale.com/types/logid" -) - -func main() { - collection := flag.String("c", "", "logtail collection name to read") - apiKey := flag.String("p", "", "logtail API key") - timeout := flag.Duration("t", 0, "timeout after which logreprocess quits") - flag.Parse() - if len(flag.Args()) != 0 { - flag.Usage() - os.Exit(1) - } - log.SetFlags(0) - - if *timeout != 0 { - go func() { - <-time.After(*timeout) - log.Printf("logreprocess: timeout reached, quitting") - os.Exit(1) - }() - } - - req, err := http.NewRequest("GET", "https://log.tailscale.io/c/"+*collection+"?stream=true", nil) - if err != nil { - log.Fatal(err) - } - req.SetBasicAuth(*apiKey, "") - resp, err := http.DefaultClient.Do(req) - if err != nil { - log.Fatal(err) - } - defer resp.Body.Close() - - if resp.StatusCode != 200 { - b, err := io.ReadAll(resp.Body) - if err != nil { - log.Fatalf("logreprocess: read error %d: %v", resp.StatusCode, err) - } - log.Fatalf("logreprocess: read error %d: %s", resp.StatusCode, string(b)) - } - - tracebackCache := make(map[logid.PublicID]*ProcessedMsg) - - scanner := bufio.NewScanner(resp.Body) - for scanner.Scan() { - var msg Msg - if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil { - log.Fatalf("logreprocess of %q: %v", string(scanner.Bytes()), err) - } - var pMsg *ProcessedMsg - if pMsg = tracebackCache[msg.Logtail.Instance]; pMsg != nil { - pMsg.Text += "\n" + msg.Text - if strings.HasPrefix(msg.Text, "Exception: ") { - delete(tracebackCache, msg.Logtail.Instance) - } else { - continue // write later - } - } else { - pMsg = &ProcessedMsg{ - OrigInstance: msg.Logtail.Instance, - Text: msg.Text, - } - pMsg.Logtail.ClientTime = msg.Logtail.ClientTime - } - - if strings.HasPrefix(msg.Text, "Traceback (most recent call last):") { - tracebackCache[msg.Logtail.Instance] = pMsg - continue // write later - } - - b, err := json.Marshal(pMsg) - if err != nil { - log.Fatal(err) - } - log.Printf("%s", b) - } - if err := scanner.Err(); err != nil { - log.Fatal(err) - } -} - -type Msg struct { - Logtail struct { - Instance logid.PublicID `json:"instance"` - ClientTime time.Time `json:"client_time"` - } `json:"logtail"` - - Text string `json:"text"` -} - -type ProcessedMsg struct { - Logtail struct { - ClientTime time.Time `json:"client_time"` - } `json:"logtail"` - - OrigInstance logid.PublicID `json:"orig_instance"` - Text string `json:"text"` -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The logreprocess program tails a log and reprocesses it. +package main + +import ( + "bufio" + "encoding/json" + "flag" + "io" + "log" + "net/http" + "os" + "strings" + "time" + + "tailscale.com/types/logid" +) + +func main() { + collection := flag.String("c", "", "logtail collection name to read") + apiKey := flag.String("p", "", "logtail API key") + timeout := flag.Duration("t", 0, "timeout after which logreprocess quits") + flag.Parse() + if len(flag.Args()) != 0 { + flag.Usage() + os.Exit(1) + } + log.SetFlags(0) + + if *timeout != 0 { + go func() { + <-time.After(*timeout) + log.Printf("logreprocess: timeout reached, quitting") + os.Exit(1) + }() + } + + req, err := http.NewRequest("GET", "https://log.tailscale.io/c/"+*collection+"?stream=true", nil) + if err != nil { + log.Fatal(err) + } + req.SetBasicAuth(*apiKey, "") + resp, err := http.DefaultClient.Do(req) + if err != nil { + log.Fatal(err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + b, err := io.ReadAll(resp.Body) + if err != nil { + log.Fatalf("logreprocess: read error %d: %v", resp.StatusCode, err) + } + log.Fatalf("logreprocess: read error %d: %s", resp.StatusCode, string(b)) + } + + tracebackCache := make(map[logid.PublicID]*ProcessedMsg) + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + var msg Msg + if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil { + log.Fatalf("logreprocess of %q: %v", string(scanner.Bytes()), err) + } + var pMsg *ProcessedMsg + if pMsg = tracebackCache[msg.Logtail.Instance]; pMsg != nil { + pMsg.Text += "\n" + msg.Text + if strings.HasPrefix(msg.Text, "Exception: ") { + delete(tracebackCache, msg.Logtail.Instance) + } else { + continue // write later + } + } else { + pMsg = &ProcessedMsg{ + OrigInstance: msg.Logtail.Instance, + Text: msg.Text, + } + pMsg.Logtail.ClientTime = msg.Logtail.ClientTime + } + + if strings.HasPrefix(msg.Text, "Traceback (most recent call last):") { + tracebackCache[msg.Logtail.Instance] = pMsg + continue // write later + } + + b, err := json.Marshal(pMsg) + if err != nil { + log.Fatal(err) + } + log.Printf("%s", b) + } + if err := scanner.Err(); err != nil { + log.Fatal(err) + } +} + +type Msg struct { + Logtail struct { + Instance logid.PublicID `json:"instance"` + ClientTime time.Time `json:"client_time"` + } `json:"logtail"` + + Text string `json:"text"` +} + +type ProcessedMsg struct { + Logtail struct { + ClientTime time.Time `json:"client_time"` + } `json:"logtail"` + + OrigInstance logid.PublicID `json:"orig_instance"` + Text string `json:"text"` +} diff --git a/logtail/example/logtail/logtail.go b/logtail/example/logtail/logtail.go index e77705513..0c9e44258 100644 --- a/logtail/example/logtail/logtail.go +++ b/logtail/example/logtail/logtail.go @@ -1,46 +1,46 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The logtail program logs stdin. -package main - -import ( - "bufio" - "flag" - "io" - "log" - "os" - - "tailscale.com/logtail" - "tailscale.com/types/logid" -) - -func main() { - collection := flag.String("c", "", "logtail collection name") - privateID := flag.String("k", "", "machine private identifier, 32-bytes in hex") - flag.Parse() - if len(flag.Args()) != 0 { - flag.Usage() - os.Exit(1) - } - - log.SetFlags(0) - - var id logid.PrivateID - if err := id.UnmarshalText([]byte(*privateID)); err != nil { - log.Fatalf("logtail: bad -privateid: %v", err) - } - - logger := logtail.NewLogger(logtail.Config{ - Collection: *collection, - PrivateID: id, - }, log.Printf) - log.SetOutput(io.MultiWriter(logger, os.Stdout)) - defer logger.Flush() - defer log.Printf("logtail exited") - - scanner := bufio.NewScanner(os.Stdin) - for scanner.Scan() { - log.Println(scanner.Text()) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The logtail program logs stdin. +package main + +import ( + "bufio" + "flag" + "io" + "log" + "os" + + "tailscale.com/logtail" + "tailscale.com/types/logid" +) + +func main() { + collection := flag.String("c", "", "logtail collection name") + privateID := flag.String("k", "", "machine private identifier, 32-bytes in hex") + flag.Parse() + if len(flag.Args()) != 0 { + flag.Usage() + os.Exit(1) + } + + log.SetFlags(0) + + var id logid.PrivateID + if err := id.UnmarshalText([]byte(*privateID)); err != nil { + log.Fatalf("logtail: bad -privateid: %v", err) + } + + logger := logtail.NewLogger(logtail.Config{ + Collection: *collection, + PrivateID: id, + }, log.Printf) + log.SetOutput(io.MultiWriter(logger, os.Stdout)) + defer logger.Flush() + defer log.Printf("logtail exited") + + scanner := bufio.NewScanner(os.Stdin) + for scanner.Scan() { + log.Println(scanner.Text()) + } +} diff --git a/logtail/filch/filch.go b/logtail/filch/filch.go index 886fe239c..d00206dd5 100644 --- a/logtail/filch/filch.go +++ b/logtail/filch/filch.go @@ -1,284 +1,284 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package filch is a file system queue that pilfers your stderr. -// (A FILe CHannel that filches.) -package filch - -import ( - "bufio" - "bytes" - "fmt" - "io" - "os" - "sync" -) - -var stderrFD = 2 // a variable for testing - -const defaultMaxFileSize = 50 << 20 - -type Options struct { - ReplaceStderr bool // dup over fd 2 so everything written to stderr comes here - MaxFileSize int -} - -// A Filch uses two alternating files as a simplistic ring buffer. -type Filch struct { - OrigStderr *os.File - - mu sync.Mutex - cur *os.File - alt *os.File - altscan *bufio.Scanner - recovered int64 - - maxFileSize int64 - writeCounter int - - // buf is an initial buffer for altscan. - // As of August 2021, 99.96% of all log lines - // are below 4096 bytes in length. - // Since this cutoff is arbitrary, instead of using 4096, - // we subtract off the size of the rest of the struct - // so that the whole struct takes 4096 bytes - // (less on 32 bit platforms). - // This reduces allocation waste. - buf [4096 - 64]byte -} - -// TryReadline implements the logtail.Buffer interface. -func (f *Filch) TryReadLine() ([]byte, error) { - f.mu.Lock() - defer f.mu.Unlock() - - if f.altscan != nil { - if b, err := f.scan(); b != nil || err != nil { - return b, err - } - } - - f.cur, f.alt = f.alt, f.cur - if f.OrigStderr != nil { - if err := dup2Stderr(f.cur); err != nil { - return nil, err - } - } - if _, err := f.alt.Seek(0, io.SeekStart); err != nil { - return nil, err - } - f.altscan = bufio.NewScanner(f.alt) - f.altscan.Buffer(f.buf[:], bufio.MaxScanTokenSize) - f.altscan.Split(splitLines) - return f.scan() -} - -func (f *Filch) scan() ([]byte, error) { - if f.altscan.Scan() { - return f.altscan.Bytes(), nil - } - err := f.altscan.Err() - err2 := f.alt.Truncate(0) - _, err3 := f.alt.Seek(0, io.SeekStart) - f.altscan = nil - if err != nil { - return nil, err - } - if err2 != nil { - return nil, err2 - } - if err3 != nil { - return nil, err3 - } - return nil, nil -} - -// Write implements the logtail.Buffer interface. -func (f *Filch) Write(b []byte) (int, error) { - f.mu.Lock() - defer f.mu.Unlock() - if f.writeCounter == 100 { - // Check the file size every 100 writes. - f.writeCounter = 0 - fi, err := f.cur.Stat() - if err != nil { - return 0, err - } - if fi.Size() >= f.maxFileSize { - // This most likely means we are not draining. - // To limit the amount of space we use, throw away the old logs. - if err := moveContents(f.alt, f.cur); err != nil { - return 0, err - } - } - } - f.writeCounter++ - - if len(b) == 0 || b[len(b)-1] != '\n' { - bnl := make([]byte, len(b)+1) - copy(bnl, b) - bnl[len(bnl)-1] = '\n' - return f.cur.Write(bnl) - } - return f.cur.Write(b) -} - -// Close closes the Filch, releasing all os resources. -func (f *Filch) Close() (err error) { - f.mu.Lock() - defer f.mu.Unlock() - - if f.OrigStderr != nil { - if err2 := unsaveStderr(f.OrigStderr); err == nil { - err = err2 - } - f.OrigStderr = nil - } - - if err2 := f.cur.Close(); err == nil { - err = err2 - } - if err2 := f.alt.Close(); err == nil { - err = err2 - } - - return err -} - -// New creates a new filch around two log files, each starting with filePrefix. -func New(filePrefix string, opts Options) (f *Filch, err error) { - var f1, f2 *os.File - defer func() { - if err != nil { - if f1 != nil { - f1.Close() - } - if f2 != nil { - f2.Close() - } - err = fmt.Errorf("filch: %s", err) - } - }() - - path1 := filePrefix + ".log1.txt" - path2 := filePrefix + ".log2.txt" - - f1, err = os.OpenFile(path1, os.O_CREATE|os.O_RDWR, 0600) - if err != nil { - return nil, err - } - f2, err = os.OpenFile(path2, os.O_CREATE|os.O_RDWR, 0600) - if err != nil { - return nil, err - } - - fi1, err := f1.Stat() - if err != nil { - return nil, err - } - fi2, err := f2.Stat() - if err != nil { - return nil, err - } - - mfs := defaultMaxFileSize - if opts.MaxFileSize > 0 { - mfs = opts.MaxFileSize - } - f = &Filch{ - OrigStderr: os.Stderr, // temporary, for past logs recovery - maxFileSize: int64(mfs), - } - - // Neither, either, or both files may exist and contain logs from - // the last time the process ran. The three cases are: - // - // - neither: all logs were read out and files were truncated - // - either: logs were being written into one of the files - // - both: the files were swapped and were starting to be - // read out, while new logs streamed into the other - // file, but the read out did not complete - if n := fi1.Size() + fi2.Size(); n > 0 { - f.recovered = n - } - switch { - case fi1.Size() > 0 && fi2.Size() == 0: - f.cur, f.alt = f2, f1 - case fi2.Size() > 0 && fi1.Size() == 0: - f.cur, f.alt = f1, f2 - case fi1.Size() > 0 && fi2.Size() > 0: // both - // We need to pick one of the files to be the elder, - // which we do using the mtime. - var older, newer *os.File - if fi1.ModTime().Before(fi2.ModTime()) { - older, newer = f1, f2 - } else { - older, newer = f2, f1 - } - if err := moveContents(older, newer); err != nil { - fmt.Fprintf(f.OrigStderr, "filch: recover move failed: %v\n", err) - fmt.Fprintf(older, "filch: recover move failed: %v\n", err) - } - f.cur, f.alt = newer, older - default: - f.cur, f.alt = f1, f2 // does not matter - } - if f.recovered > 0 { - f.altscan = bufio.NewScanner(f.alt) - f.altscan.Buffer(f.buf[:], bufio.MaxScanTokenSize) - f.altscan.Split(splitLines) - } - - f.OrigStderr = nil - if opts.ReplaceStderr { - f.OrigStderr, err = saveStderr() - if err != nil { - return nil, err - } - if err := dup2Stderr(f.cur); err != nil { - return nil, err - } - } - - return f, nil -} - -func moveContents(dst, src *os.File) (err error) { - defer func() { - _, err2 := src.Seek(0, io.SeekStart) - err3 := src.Truncate(0) - _, err4 := dst.Seek(0, io.SeekStart) - if err == nil { - err = err2 - } - if err == nil { - err = err3 - } - if err == nil { - err = err4 - } - }() - if _, err := src.Seek(0, io.SeekStart); err != nil { - return err - } - if _, err := dst.Seek(0, io.SeekStart); err != nil { - return err - } - if _, err := io.Copy(dst, src); err != nil { - return err - } - return nil -} - -func splitLines(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := bytes.IndexByte(data, '\n'); i >= 0 { - return i + 1, data[0 : i+1], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package filch is a file system queue that pilfers your stderr. +// (A FILe CHannel that filches.) +package filch + +import ( + "bufio" + "bytes" + "fmt" + "io" + "os" + "sync" +) + +var stderrFD = 2 // a variable for testing + +const defaultMaxFileSize = 50 << 20 + +type Options struct { + ReplaceStderr bool // dup over fd 2 so everything written to stderr comes here + MaxFileSize int +} + +// A Filch uses two alternating files as a simplistic ring buffer. +type Filch struct { + OrigStderr *os.File + + mu sync.Mutex + cur *os.File + alt *os.File + altscan *bufio.Scanner + recovered int64 + + maxFileSize int64 + writeCounter int + + // buf is an initial buffer for altscan. + // As of August 2021, 99.96% of all log lines + // are below 4096 bytes in length. + // Since this cutoff is arbitrary, instead of using 4096, + // we subtract off the size of the rest of the struct + // so that the whole struct takes 4096 bytes + // (less on 32 bit platforms). + // This reduces allocation waste. + buf [4096 - 64]byte +} + +// TryReadline implements the logtail.Buffer interface. +func (f *Filch) TryReadLine() ([]byte, error) { + f.mu.Lock() + defer f.mu.Unlock() + + if f.altscan != nil { + if b, err := f.scan(); b != nil || err != nil { + return b, err + } + } + + f.cur, f.alt = f.alt, f.cur + if f.OrigStderr != nil { + if err := dup2Stderr(f.cur); err != nil { + return nil, err + } + } + if _, err := f.alt.Seek(0, io.SeekStart); err != nil { + return nil, err + } + f.altscan = bufio.NewScanner(f.alt) + f.altscan.Buffer(f.buf[:], bufio.MaxScanTokenSize) + f.altscan.Split(splitLines) + return f.scan() +} + +func (f *Filch) scan() ([]byte, error) { + if f.altscan.Scan() { + return f.altscan.Bytes(), nil + } + err := f.altscan.Err() + err2 := f.alt.Truncate(0) + _, err3 := f.alt.Seek(0, io.SeekStart) + f.altscan = nil + if err != nil { + return nil, err + } + if err2 != nil { + return nil, err2 + } + if err3 != nil { + return nil, err3 + } + return nil, nil +} + +// Write implements the logtail.Buffer interface. +func (f *Filch) Write(b []byte) (int, error) { + f.mu.Lock() + defer f.mu.Unlock() + if f.writeCounter == 100 { + // Check the file size every 100 writes. + f.writeCounter = 0 + fi, err := f.cur.Stat() + if err != nil { + return 0, err + } + if fi.Size() >= f.maxFileSize { + // This most likely means we are not draining. + // To limit the amount of space we use, throw away the old logs. + if err := moveContents(f.alt, f.cur); err != nil { + return 0, err + } + } + } + f.writeCounter++ + + if len(b) == 0 || b[len(b)-1] != '\n' { + bnl := make([]byte, len(b)+1) + copy(bnl, b) + bnl[len(bnl)-1] = '\n' + return f.cur.Write(bnl) + } + return f.cur.Write(b) +} + +// Close closes the Filch, releasing all os resources. +func (f *Filch) Close() (err error) { + f.mu.Lock() + defer f.mu.Unlock() + + if f.OrigStderr != nil { + if err2 := unsaveStderr(f.OrigStderr); err == nil { + err = err2 + } + f.OrigStderr = nil + } + + if err2 := f.cur.Close(); err == nil { + err = err2 + } + if err2 := f.alt.Close(); err == nil { + err = err2 + } + + return err +} + +// New creates a new filch around two log files, each starting with filePrefix. +func New(filePrefix string, opts Options) (f *Filch, err error) { + var f1, f2 *os.File + defer func() { + if err != nil { + if f1 != nil { + f1.Close() + } + if f2 != nil { + f2.Close() + } + err = fmt.Errorf("filch: %s", err) + } + }() + + path1 := filePrefix + ".log1.txt" + path2 := filePrefix + ".log2.txt" + + f1, err = os.OpenFile(path1, os.O_CREATE|os.O_RDWR, 0600) + if err != nil { + return nil, err + } + f2, err = os.OpenFile(path2, os.O_CREATE|os.O_RDWR, 0600) + if err != nil { + return nil, err + } + + fi1, err := f1.Stat() + if err != nil { + return nil, err + } + fi2, err := f2.Stat() + if err != nil { + return nil, err + } + + mfs := defaultMaxFileSize + if opts.MaxFileSize > 0 { + mfs = opts.MaxFileSize + } + f = &Filch{ + OrigStderr: os.Stderr, // temporary, for past logs recovery + maxFileSize: int64(mfs), + } + + // Neither, either, or both files may exist and contain logs from + // the last time the process ran. The three cases are: + // + // - neither: all logs were read out and files were truncated + // - either: logs were being written into one of the files + // - both: the files were swapped and were starting to be + // read out, while new logs streamed into the other + // file, but the read out did not complete + if n := fi1.Size() + fi2.Size(); n > 0 { + f.recovered = n + } + switch { + case fi1.Size() > 0 && fi2.Size() == 0: + f.cur, f.alt = f2, f1 + case fi2.Size() > 0 && fi1.Size() == 0: + f.cur, f.alt = f1, f2 + case fi1.Size() > 0 && fi2.Size() > 0: // both + // We need to pick one of the files to be the elder, + // which we do using the mtime. + var older, newer *os.File + if fi1.ModTime().Before(fi2.ModTime()) { + older, newer = f1, f2 + } else { + older, newer = f2, f1 + } + if err := moveContents(older, newer); err != nil { + fmt.Fprintf(f.OrigStderr, "filch: recover move failed: %v\n", err) + fmt.Fprintf(older, "filch: recover move failed: %v\n", err) + } + f.cur, f.alt = newer, older + default: + f.cur, f.alt = f1, f2 // does not matter + } + if f.recovered > 0 { + f.altscan = bufio.NewScanner(f.alt) + f.altscan.Buffer(f.buf[:], bufio.MaxScanTokenSize) + f.altscan.Split(splitLines) + } + + f.OrigStderr = nil + if opts.ReplaceStderr { + f.OrigStderr, err = saveStderr() + if err != nil { + return nil, err + } + if err := dup2Stderr(f.cur); err != nil { + return nil, err + } + } + + return f, nil +} + +func moveContents(dst, src *os.File) (err error) { + defer func() { + _, err2 := src.Seek(0, io.SeekStart) + err3 := src.Truncate(0) + _, err4 := dst.Seek(0, io.SeekStart) + if err == nil { + err = err2 + } + if err == nil { + err = err3 + } + if err == nil { + err = err4 + } + }() + if _, err := src.Seek(0, io.SeekStart); err != nil { + return err + } + if _, err := dst.Seek(0, io.SeekStart); err != nil { + return err + } + if _, err := io.Copy(dst, src); err != nil { + return err + } + return nil +} + +func splitLines(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := bytes.IndexByte(data, '\n'); i >= 0 { + return i + 1, data[0 : i+1], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil +} diff --git a/logtail/filch/filch_stub.go b/logtail/filch/filch_stub.go index fe718d150..3bb82b190 100644 --- a/logtail/filch/filch_stub.go +++ b/logtail/filch/filch_stub.go @@ -1,23 +1,23 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build wasm || plan9 || tamago - -package filch - -import ( - "os" -) - -func saveStderr() (*os.File, error) { - return os.Stderr, nil -} - -func unsaveStderr(f *os.File) error { - os.Stderr = f - return nil -} - -func dup2Stderr(f *os.File) error { - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build wasm || plan9 || tamago + +package filch + +import ( + "os" +) + +func saveStderr() (*os.File, error) { + return os.Stderr, nil +} + +func unsaveStderr(f *os.File) error { + os.Stderr = f + return nil +} + +func dup2Stderr(f *os.File) error { + return nil +} diff --git a/logtail/filch/filch_unix.go b/logtail/filch/filch_unix.go index b06ef6afd..2eae70ace 100644 --- a/logtail/filch/filch_unix.go +++ b/logtail/filch/filch_unix.go @@ -1,30 +1,30 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows && !wasm && !plan9 && !tamago - -package filch - -import ( - "os" - - "golang.org/x/sys/unix" -) - -func saveStderr() (*os.File, error) { - fd, err := unix.Dup(stderrFD) - if err != nil { - return nil, err - } - return os.NewFile(uintptr(fd), "stderr"), nil -} - -func unsaveStderr(f *os.File) error { - err := dup2Stderr(f) - f.Close() - return err -} - -func dup2Stderr(f *os.File) error { - return unix.Dup2(int(f.Fd()), stderrFD) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows && !wasm && !plan9 && !tamago + +package filch + +import ( + "os" + + "golang.org/x/sys/unix" +) + +func saveStderr() (*os.File, error) { + fd, err := unix.Dup(stderrFD) + if err != nil { + return nil, err + } + return os.NewFile(uintptr(fd), "stderr"), nil +} + +func unsaveStderr(f *os.File) error { + err := dup2Stderr(f) + f.Close() + return err +} + +func dup2Stderr(f *os.File) error { + return unix.Dup2(int(f.Fd()), stderrFD) +} diff --git a/logtail/filch/filch_windows.go b/logtail/filch/filch_windows.go index 1419d6606..d60514bf0 100644 --- a/logtail/filch/filch_windows.go +++ b/logtail/filch/filch_windows.go @@ -1,43 +1,43 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package filch - -import ( - "fmt" - "os" - "syscall" -) - -var kernel32 = syscall.MustLoadDLL("kernel32.dll") -var procSetStdHandle = kernel32.MustFindProc("SetStdHandle") - -func setStdHandle(stdHandle int32, handle syscall.Handle) error { - r, _, e := syscall.Syscall(procSetStdHandle.Addr(), 2, uintptr(stdHandle), uintptr(handle), 0) - if r == 0 { - if e != 0 { - return error(e) - } - return syscall.EINVAL - } - return nil -} - -func saveStderr() (*os.File, error) { - return os.Stderr, nil -} - -func unsaveStderr(f *os.File) error { - os.Stderr = f - return nil -} - -func dup2Stderr(f *os.File) error { - fd := int(f.Fd()) - err := setStdHandle(syscall.STD_ERROR_HANDLE, syscall.Handle(fd)) - if err != nil { - return fmt.Errorf("dup2Stderr: %w", err) - } - os.Stderr = f - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package filch + +import ( + "fmt" + "os" + "syscall" +) + +var kernel32 = syscall.MustLoadDLL("kernel32.dll") +var procSetStdHandle = kernel32.MustFindProc("SetStdHandle") + +func setStdHandle(stdHandle int32, handle syscall.Handle) error { + r, _, e := syscall.Syscall(procSetStdHandle.Addr(), 2, uintptr(stdHandle), uintptr(handle), 0) + if r == 0 { + if e != 0 { + return error(e) + } + return syscall.EINVAL + } + return nil +} + +func saveStderr() (*os.File, error) { + return os.Stderr, nil +} + +func unsaveStderr(f *os.File) error { + os.Stderr = f + return nil +} + +func dup2Stderr(f *os.File) error { + fd := int(f.Fd()) + err := setStdHandle(syscall.STD_ERROR_HANDLE, syscall.Handle(fd)) + if err != nil { + return fmt.Errorf("dup2Stderr: %w", err) + } + os.Stderr = f + return nil +} diff --git a/metrics/fds_linux.go b/metrics/fds_linux.go index 66ebb419d..34740c2bb 100644 --- a/metrics/fds_linux.go +++ b/metrics/fds_linux.go @@ -1,41 +1,41 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package metrics - -import ( - "io/fs" - "sync" - - "go4.org/mem" - "tailscale.com/util/dirwalk" -) - -// counter is a reusable counter for counting file descriptors. -type counter struct { - n int - - // cb is the (*counter).count method value. Creating it allocates, - // so we have to save it away and use a sync.Pool to keep currentFDs - // amortized alloc-free. - cb func(name mem.RO, de fs.DirEntry) error -} - -var counterPool = &sync.Pool{New: func() any { - c := new(counter) - c.cb = c.count - return c -}} - -func (c *counter) count(name mem.RO, de fs.DirEntry) error { - c.n++ - return nil -} - -func currentFDs() int { - c := counterPool.Get().(*counter) - defer counterPool.Put(c) - c.n = 0 - dirwalk.WalkShallow(mem.S("/proc/self/fd"), c.cb) - return c.n -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package metrics + +import ( + "io/fs" + "sync" + + "go4.org/mem" + "tailscale.com/util/dirwalk" +) + +// counter is a reusable counter for counting file descriptors. +type counter struct { + n int + + // cb is the (*counter).count method value. Creating it allocates, + // so we have to save it away and use a sync.Pool to keep currentFDs + // amortized alloc-free. + cb func(name mem.RO, de fs.DirEntry) error +} + +var counterPool = &sync.Pool{New: func() any { + c := new(counter) + c.cb = c.count + return c +}} + +func (c *counter) count(name mem.RO, de fs.DirEntry) error { + c.n++ + return nil +} + +func currentFDs() int { + c := counterPool.Get().(*counter) + defer counterPool.Put(c) + c.n = 0 + dirwalk.WalkShallow(mem.S("/proc/self/fd"), c.cb) + return c.n +} diff --git a/metrics/fds_notlinux.go b/metrics/fds_notlinux.go index 5a59d4de9..2dae97cad 100644 --- a/metrics/fds_notlinux.go +++ b/metrics/fds_notlinux.go @@ -1,8 +1,8 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux - -package metrics - -func currentFDs() int { return 0 } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux + +package metrics + +func currentFDs() int { return 0 } diff --git a/metrics/metrics.go b/metrics/metrics.go index 0f67ffa30..a07ddccae 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -1,163 +1,163 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package metrics contains expvar & Prometheus types and code used by -// Tailscale for monitoring. -package metrics - -import ( - "expvar" - "fmt" - "io" - "slices" - "strings" -) - -// Set is a string-to-Var map variable that satisfies the expvar.Var -// interface. -// -// Semantically, this is mapped by tsweb's Prometheus exporter as a -// collection of unrelated variables exported with a common prefix. -// -// This lets us have tsweb recognize *expvar.Map for different -// purposes in the future. (Or perhaps all uses of expvar.Map will -// require explicit types like this one, declaring how we want tsweb -// to export it to Prometheus.) -type Set struct { - expvar.Map -} - -// LabelMap is a string-to-Var map variable that satisfies the -// expvar.Var interface. -// -// Semantically, this is mapped by tsweb's Prometheus exporter as a -// collection of variables with the same name, with a varying label -// value. Use this to export things that are intuitively breakdowns -// into different buckets. -type LabelMap struct { - Label string - expvar.Map -} - -// SetInt64 sets the *Int value stored under the given map key. -func (m *LabelMap) SetInt64(key string, v int64) { - m.Get(key).Set(v) -} - -// Get returns a direct pointer to the expvar.Int for key, creating it -// if necessary. -func (m *LabelMap) Get(key string) *expvar.Int { - m.Add(key, 0) - return m.Map.Get(key).(*expvar.Int) -} - -// GetIncrFunc returns a function that increments the expvar.Int named by key. -// -// Most callers should not need this; it exists to satisfy an -// interface elsewhere. -func (m *LabelMap) GetIncrFunc(key string) func(delta int64) { - return m.Get(key).Add -} - -// GetFloat returns a direct pointer to the expvar.Float for key, creating it -// if necessary. -func (m *LabelMap) GetFloat(key string) *expvar.Float { - m.AddFloat(key, 0.0) - return m.Map.Get(key).(*expvar.Float) -} - -// CurrentFDs reports how many file descriptors are currently open. -// -// It only works on Linux. It returns zero otherwise. -func CurrentFDs() int { - return currentFDs() -} - -// Histogram is a histogram of values. -// It should be created with NewHistogram. -type Histogram struct { - // buckets is a list of bucket boundaries, in increasing order. - buckets []float64 - - // bucketStrings is a list of the same buckets, but as strings. - // This are allocated once at creation time by NewHistogram. - bucketStrings []string - - bucketVars []expvar.Int - sum expvar.Float - count expvar.Int -} - -// NewHistogram returns a new histogram that reports to the given -// expvar map under the given name. -// -// The buckets are the boundaries of the histogram buckets, in -// increasing order. The last bucket is +Inf. -func NewHistogram(buckets []float64) *Histogram { - if !slices.IsSorted(buckets) { - panic("buckets must be sorted") - } - labels := make([]string, len(buckets)) - for i, b := range buckets { - labels[i] = fmt.Sprintf("%v", b) - } - h := &Histogram{ - buckets: buckets, - bucketStrings: labels, - bucketVars: make([]expvar.Int, len(buckets)), - } - return h -} - -// Observe records a new observation in the histogram. -func (h *Histogram) Observe(v float64) { - h.sum.Add(v) - h.count.Add(1) - for i, b := range h.buckets { - if v <= b { - h.bucketVars[i].Add(1) - } - } -} - -// String returns a JSON representation of the histogram. -// This is used to satisfy the expvar.Var interface. -func (h *Histogram) String() string { - var b strings.Builder - fmt.Fprintf(&b, "{") - first := true - h.Do(func(kv expvar.KeyValue) { - if !first { - fmt.Fprintf(&b, ",") - } - fmt.Fprintf(&b, "%q: ", kv.Key) - if kv.Value != nil { - fmt.Fprintf(&b, "%v", kv.Value) - } else { - fmt.Fprint(&b, "null") - } - first = false - }) - fmt.Fprintf(&b, ",\"sum\": %v", &h.sum) - fmt.Fprintf(&b, ",\"count\": %v", &h.count) - fmt.Fprintf(&b, "}") - return b.String() -} - -// Do calls f for each bucket in the histogram. -func (h *Histogram) Do(f func(expvar.KeyValue)) { - for i := range h.bucketVars { - f(expvar.KeyValue{Key: h.bucketStrings[i], Value: &h.bucketVars[i]}) - } - f(expvar.KeyValue{Key: "+Inf", Value: &h.count}) -} - -// PromExport writes the histogram to w in Prometheus exposition format. -func (h *Histogram) PromExport(w io.Writer, name string) { - fmt.Fprintf(w, "# TYPE %s histogram\n", name) - h.Do(func(kv expvar.KeyValue) { - fmt.Fprintf(w, "%s_bucket{le=%q} %v\n", name, kv.Key, kv.Value) - }) - fmt.Fprintf(w, "%s_sum %v\n", name, &h.sum) - fmt.Fprintf(w, "%s_count %v\n", name, &h.count) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package metrics contains expvar & Prometheus types and code used by +// Tailscale for monitoring. +package metrics + +import ( + "expvar" + "fmt" + "io" + "slices" + "strings" +) + +// Set is a string-to-Var map variable that satisfies the expvar.Var +// interface. +// +// Semantically, this is mapped by tsweb's Prometheus exporter as a +// collection of unrelated variables exported with a common prefix. +// +// This lets us have tsweb recognize *expvar.Map for different +// purposes in the future. (Or perhaps all uses of expvar.Map will +// require explicit types like this one, declaring how we want tsweb +// to export it to Prometheus.) +type Set struct { + expvar.Map +} + +// LabelMap is a string-to-Var map variable that satisfies the +// expvar.Var interface. +// +// Semantically, this is mapped by tsweb's Prometheus exporter as a +// collection of variables with the same name, with a varying label +// value. Use this to export things that are intuitively breakdowns +// into different buckets. +type LabelMap struct { + Label string + expvar.Map +} + +// SetInt64 sets the *Int value stored under the given map key. +func (m *LabelMap) SetInt64(key string, v int64) { + m.Get(key).Set(v) +} + +// Get returns a direct pointer to the expvar.Int for key, creating it +// if necessary. +func (m *LabelMap) Get(key string) *expvar.Int { + m.Add(key, 0) + return m.Map.Get(key).(*expvar.Int) +} + +// GetIncrFunc returns a function that increments the expvar.Int named by key. +// +// Most callers should not need this; it exists to satisfy an +// interface elsewhere. +func (m *LabelMap) GetIncrFunc(key string) func(delta int64) { + return m.Get(key).Add +} + +// GetFloat returns a direct pointer to the expvar.Float for key, creating it +// if necessary. +func (m *LabelMap) GetFloat(key string) *expvar.Float { + m.AddFloat(key, 0.0) + return m.Map.Get(key).(*expvar.Float) +} + +// CurrentFDs reports how many file descriptors are currently open. +// +// It only works on Linux. It returns zero otherwise. +func CurrentFDs() int { + return currentFDs() +} + +// Histogram is a histogram of values. +// It should be created with NewHistogram. +type Histogram struct { + // buckets is a list of bucket boundaries, in increasing order. + buckets []float64 + + // bucketStrings is a list of the same buckets, but as strings. + // This are allocated once at creation time by NewHistogram. + bucketStrings []string + + bucketVars []expvar.Int + sum expvar.Float + count expvar.Int +} + +// NewHistogram returns a new histogram that reports to the given +// expvar map under the given name. +// +// The buckets are the boundaries of the histogram buckets, in +// increasing order. The last bucket is +Inf. +func NewHistogram(buckets []float64) *Histogram { + if !slices.IsSorted(buckets) { + panic("buckets must be sorted") + } + labels := make([]string, len(buckets)) + for i, b := range buckets { + labels[i] = fmt.Sprintf("%v", b) + } + h := &Histogram{ + buckets: buckets, + bucketStrings: labels, + bucketVars: make([]expvar.Int, len(buckets)), + } + return h +} + +// Observe records a new observation in the histogram. +func (h *Histogram) Observe(v float64) { + h.sum.Add(v) + h.count.Add(1) + for i, b := range h.buckets { + if v <= b { + h.bucketVars[i].Add(1) + } + } +} + +// String returns a JSON representation of the histogram. +// This is used to satisfy the expvar.Var interface. +func (h *Histogram) String() string { + var b strings.Builder + fmt.Fprintf(&b, "{") + first := true + h.Do(func(kv expvar.KeyValue) { + if !first { + fmt.Fprintf(&b, ",") + } + fmt.Fprintf(&b, "%q: ", kv.Key) + if kv.Value != nil { + fmt.Fprintf(&b, "%v", kv.Value) + } else { + fmt.Fprint(&b, "null") + } + first = false + }) + fmt.Fprintf(&b, ",\"sum\": %v", &h.sum) + fmt.Fprintf(&b, ",\"count\": %v", &h.count) + fmt.Fprintf(&b, "}") + return b.String() +} + +// Do calls f for each bucket in the histogram. +func (h *Histogram) Do(f func(expvar.KeyValue)) { + for i := range h.bucketVars { + f(expvar.KeyValue{Key: h.bucketStrings[i], Value: &h.bucketVars[i]}) + } + f(expvar.KeyValue{Key: "+Inf", Value: &h.count}) +} + +// PromExport writes the histogram to w in Prometheus exposition format. +func (h *Histogram) PromExport(w io.Writer, name string) { + fmt.Fprintf(w, "# TYPE %s histogram\n", name) + h.Do(func(kv expvar.KeyValue) { + fmt.Fprintf(w, "%s_bucket{le=%q} %v\n", name, kv.Key, kv.Value) + }) + fmt.Fprintf(w, "%s_sum %v\n", name, &h.sum) + fmt.Fprintf(w, "%s_count %v\n", name, &h.count) +} diff --git a/net/art/art_test.go b/net/art/art_test.go index e3a427107..daf8553ca 100644 --- a/net/art/art_test.go +++ b/net/art/art_test.go @@ -1,20 +1,20 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package art - -import ( - "os" - "testing" - - "tailscale.com/util/cibuild" -) - -func TestMain(m *testing.M) { - if cibuild.On() { - // Skip CI on GitHub for now - // TODO: https://github.com/tailscale/tailscale/issues/7866 - os.Exit(0) - } - os.Exit(m.Run()) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package art + +import ( + "os" + "testing" + + "tailscale.com/util/cibuild" +) + +func TestMain(m *testing.M) { + if cibuild.On() { + // Skip CI on GitHub for now + // TODO: https://github.com/tailscale/tailscale/issues/7866 + os.Exit(0) + } + os.Exit(m.Run()) +} diff --git a/net/art/table.go b/net/art/table.go index 2e130d82f..fa3975778 100644 --- a/net/art/table.go +++ b/net/art/table.go @@ -1,641 +1,641 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package art provides a routing table that implements the Allotment Routing -// Table (ART) algorithm by Donald Knuth, as described in the paper by Yoichi -// Hariguchi. -// -// ART outperforms the traditional radix tree implementations for route lookups, -// insertions, and deletions. -// -// For more information, see Yoichi Hariguchi's paper: -// https://cseweb.ucsd.edu//~varghese/TEACH/cs228/artlookup.pdf -package art - -import ( - "bytes" - "encoding/binary" - "fmt" - "io" - "math/bits" - "net/netip" - "strings" - "sync" -) - -const ( - debugInsert = false - debugDelete = false -) - -// Table is an IPv4 and IPv6 routing table. -type Table[T any] struct { - v4 strideTable[T] - v6 strideTable[T] - initOnce sync.Once -} - -func (t *Table[T]) init() { - t.initOnce.Do(func() { - t.v4.prefix = netip.PrefixFrom(netip.IPv4Unspecified(), 0) - t.v6.prefix = netip.PrefixFrom(netip.IPv6Unspecified(), 0) - }) -} - -func (t *Table[T]) tableForAddr(addr netip.Addr) *strideTable[T] { - if addr.Is6() { - return &t.v6 - } - return &t.v4 -} - -// Get does a route lookup for addr and returns the associated value, or nil if -// no route matched. -func (t *Table[T]) Get(addr netip.Addr) (ret T, ok bool) { - t.init() - - // Ideally we would use addr.AsSlice here, but AsSlice is just - // barely complex enough that it can't be inlined, and that in - // turn causes the slice to escape to the heap. Using As16 and - // manual slicing here helps the compiler keep Get alloc-free. - st := t.tableForAddr(addr) - rawAddr := addr.As16() - bs := rawAddr[:] - if addr.Is4() { - bs = bs[12:] - } - - i := 0 - // With path compression, we might skip over some address bits while walking - // to a strideTable leaf. This means the leaf answer we find might not be - // correct, because path compression took us down the wrong subtree. When - // that happens, we have to backtrack and figure out which most specific - // route further up the tree is relevant to addr, and return that. - // - // So, as we walk down the stride tables, each time we find a non-nil route - // result, we have to remember it and the associated strideTable prefix. - // - // We could also deal with this edge case of path compression by checking - // the strideTable prefix on each table as we descend, but that means we - // have to pay N prefix.Contains checks on every route lookup (where N is - // the number of strideTables in the path), rather than only paying M prefix - // comparisons in the edge case (where M is the number of strideTables in - // the path with a non-nil route of their own). - const maxDepth = 16 - type prefixAndRoute struct { - prefix netip.Prefix - route T - } - strideMatch := make([]prefixAndRoute, 0, maxDepth) -findLeaf: - for { - rt, rtOK, child := st.getValAndChild(bs[i]) - if rtOK { - // This strideTable contains a route that may be relevant to our - // search, remember it. - strideMatch = append(strideMatch, prefixAndRoute{st.prefix, rt}) - } - if child == nil { - // No sub-routes further down, the last thing we recorded - // in strideRoutes is tentatively the result, barring - // misdirection from path compression. - break findLeaf - } - st = child - // Path compression means we may be skipping over some intermediate - // tables. We have to skip forward to whatever depth st now references. - i = st.prefix.Bits() / 8 - } - - // Walk backwards through the hits we recorded in strideRoutes and - // stridePrefixes, returning the first one whose subtree matches addr. - // - // In the common case where path compression did not mislead us, we'll - // return on the first loop iteration because the last route we recorded was - // the correct most-specific route. - for i := len(strideMatch) - 1; i >= 0; i-- { - if m := strideMatch[i]; m.prefix.Contains(addr) { - return m.route, true - } - } - - // We either found no route hits at all (both previous loops terminated - // immediately), or we went on a wild goose chase down a compressed path for - // the wrong prefix, and also found no usable routes on the way back up to - // the root. This is a miss. - return ret, false -} - -// Insert adds pfx to the table, with value val. -// If pfx is already present in the table, its value is set to val. -func (t *Table[T]) Insert(pfx netip.Prefix, val T) { - t.init() - - // The standard library doesn't enforce normalized prefixes (where - // the non-prefix bits are all zero). These algorithms require - // normalized prefixes, so do it upfront. - pfx = pfx.Masked() - - if debugInsert { - defer func() { - fmt.Printf("%s", t.debugSummary()) - }() - fmt.Printf("\ninsert: start pfx=%s\n", pfx) - } - - st := t.tableForAddr(pfx.Addr()) - - // This algorithm is full of off-by-one headaches that boil down - // to the fact that pfx.Bits() has (2^n)+1 values, rather than - // just 2^n. For example, an IPv4 prefix length can be 0 through - // 32, which is 33 values. - // - // This extra possible value creates a lot of problems as we do - // bits and bytes math to traverse strideTables below. So, we - // treat the default route 0/0 specially here, that way the rest - // of the logic goes back to having 2^n values to reason about, - // which can be done in a nice and regular fashion with no edge - // cases. - if pfx.Bits() == 0 { - if debugInsert { - fmt.Printf("insert: default route\n") - } - st.insert(0, 0, val) - return - } - - // No matter what we do as we traverse strideTables, our final - // action will be to insert the last 1-8 bits of pfx into a - // strideTable somewhere. - // - // We calculate upfront the byte position of the end of the - // prefix; the number of bits within that byte that contain prefix - // data; and the prefix of the strideTable into which we'll - // eventually insert. - // - // We need this in a couple different branches of the code below, - // and because the possible values are 1-indexed (1 through 32 for - // ipv4, 1 through 128 for ipv6), the math is very slightly - // unusual to account for the off-by-one indexing. Do it once up - // here, with this large comment, rather than reproduce the subtle - // math in multiple places further down. - finalByteIdx := (pfx.Bits() - 1) / 8 - finalBits := pfx.Bits() - (finalByteIdx * 8) - finalStridePrefix, err := pfx.Addr().Prefix(finalByteIdx * 8) - if err != nil { - panic(fmt.Sprintf("invalid prefix requested: %s/%d", pfx.Addr(), finalByteIdx*8)) - } - if debugInsert { - fmt.Printf("insert: finalByteIdx=%d finalBits=%d finalStridePrefix=%s\n", finalByteIdx, finalBits, finalStridePrefix) - } - - // The strideTable we want to insert into is potentially at the - // end of a chain of strideTables, each one encoding 8 bits of the - // prefix. - // - // We're expecting to walk down a path of tables, although with - // prefix compression we may end up skipping some links in the - // chain, or taking wrong turns and having to course correct. - // - // As we walk down the tree, byteIdx is the byte of bs we're - // currently examining to choose our next step, and numBits is the - // number of bits that remain in pfx, starting with the byte at - // byteIdx inclusive. - bs := pfx.Addr().AsSlice() - byteIdx := 0 - numBits := pfx.Bits() - for { - if debugInsert { - fmt.Printf("insert: loop byteIdx=%d numBits=%d st.prefix=%s\n", byteIdx, numBits, st.prefix) - } - if numBits <= 8 { - if debugInsert { - fmt.Printf("insert: existing leaf st.prefix=%s addr=%d/%d\n", st.prefix, bs[finalByteIdx], finalBits) - } - // We've reached the end of the prefix, whichever - // strideTable we're looking at now is the place where we - // need to insert. - st.insert(bs[finalByteIdx], finalBits, val) - return - } - - // Otherwise, we need to go down at least one more level of - // strideTables. With prefix compression, each level of - // descent can have one of three outcomes: we find a place - // where prefix compression is possible; a place where prefix - // compression made us take a "wrong turn"; or a point along - // our intended path that we have to keep following. - child, created := st.getOrCreateChild(bs[byteIdx]) - switch { - case created: - // The subtree we need for pfx doesn't exist yet. The rest - // of the path, if we were to create it, will consist of a - // bunch of strideTables with a single child each. We can - // use path compression to elide those intermediates, and - // jump straight to the final strideTable that hosts this - // prefix. - child.prefix = finalStridePrefix - child.insert(bs[finalByteIdx], finalBits, val) - if debugInsert { - fmt.Printf("insert: new leaf st.prefix=%s child.prefix=%s addr=%d/%d\n", st.prefix, child.prefix, bs[finalByteIdx], finalBits) - } - return - case !prefixStrictlyContains(child.prefix, pfx): - // child already exists, but its prefix does not contain - // our destination. This means that the path between st - // and child was compressed by a previous insertion, and - // somewhere in the (implicit) compressed path we took a - // wrong turn, into the wrong part of st's subtree. - // - // This is okay, because pfx and child.prefix must have a - // common ancestor node somewhere between st and child. We - // can figure out what node that is, and materialize it. - // - // Once we've done that, we can immediately complete the - // remainder of the insertion in one of two ways, without - // further traversal. See a little further down for what - // those are. - if debugInsert { - fmt.Printf("insert: wrong turn, pfx=%s child.prefix=%s\n", pfx, child.prefix) - } - intermediatePrefix, addrOfExisting, addrOfNew := computePrefixSplit(child.prefix, pfx) - intermediate := &strideTable[T]{prefix: intermediatePrefix} // TODO: make this whole thing be st.AddIntermediate or something? - st.setChild(bs[byteIdx], intermediate) - intermediate.setChild(addrOfExisting, child) - - if debugInsert { - fmt.Printf("insert: new intermediate st.prefix=%s intermediate.prefix=%s child.prefix=%s\n", st.prefix, intermediate.prefix, child.prefix) - } - - // Now, we have a chain of st -> intermediate -> child. - // - // pfx either lives in a different child of intermediate, - // or in intermediate itself. For example, if we created - // the intermediate 1.2.0.0/16, pfx=1.2.3.4/32 would have - // to go into a new child of intermediate, but - // pfx=1.2.0.0/18 would go into intermediate directly. - if remain := pfx.Bits() - intermediate.prefix.Bits(); remain <= 8 { - // pfx lives in intermediate. - if debugInsert { - fmt.Printf("insert: into intermediate intermediate.prefix=%s addr=%d/%d\n", intermediate.prefix, bs[finalByteIdx], finalBits) - } - intermediate.insert(bs[finalByteIdx], finalBits, val) - } else { - // pfx lives in a different child subtree of - // intermediate. By definition this subtree doesn't - // exist at all, otherwise we'd never have entered - // this entire "wrong turn" codepath in the first - // place. - // - // This means we can apply prefix compression as we - // create this new child, and we're done. - st, created = intermediate.getOrCreateChild(addrOfNew) - if !created { - panic("new child path unexpectedly exists during path decompression") - } - st.prefix = finalStridePrefix - st.insert(bs[finalByteIdx], finalBits, val) - if debugInsert { - fmt.Printf("insert: new child st.prefix=%s addr=%d/%d\n", st.prefix, bs[finalByteIdx], finalBits) - } - } - - return - default: - // An expected child table exists along pfx's - // path. Continue traversing downwards. - st = child - byteIdx = child.prefix.Bits() / 8 - numBits = pfx.Bits() - child.prefix.Bits() - if debugInsert { - fmt.Printf("insert: descend st.prefix=%s\n", st.prefix) - } - } - } -} - -// Delete removes pfx from the table, if it is present. -func (t *Table[T]) Delete(pfx netip.Prefix) { - t.init() - - // The standard library doesn't enforce normalized prefixes (where - // the non-prefix bits are all zero). These algorithms require - // normalized prefixes, so do it upfront. - pfx = pfx.Masked() - - if debugDelete { - defer func() { - fmt.Printf("%s", t.debugSummary()) - }() - fmt.Printf("\ndelete: start pfx=%s table:\n%s", pfx, t.debugSummary()) - } - - st := t.tableForAddr(pfx.Addr()) - - // This algorithm is full of off-by-one headaches, just like - // Insert. See the comment in Insert for more details. Bottom - // line: we handle the default route as a special case, and that - // simplifies the rest of the code slightly. - if pfx.Bits() == 0 { - if debugDelete { - fmt.Printf("delete: default route\n") - } - st.delete(0, 0) - return - } - - // Deletion may drive the refcount of some strideTables down to - // zero. We need to clean up these dangling tables, so we have to - // keep track of which tables we touch on the way down, and which - // strideEntry index each child is registered in. - // - // Note that the strideIndex and strideTables entries are off-by-one. - // The child table pointer is recorded at i+1, but it is referenced by a - // particular index in the parent table, at index i. - // - // In other words: entry number strideIndexes[0] in - // strideTables[0] is the same pointer as strideTables[1]. - // - // This results in some slightly odd array accesses further down - // in this code, because in a single loop iteration we have to - // write to strideTables[N] and strideIndexes[N-1]. - strideIdx := 0 - strideTables := [16]*strideTable[T]{st} - strideIndexes := [15]uint8{} - - // Similar to Insert, navigate down the tree of strideTables, - // looking for the one that houses this prefix. This part is - // easier than with insertion, since we can bail if the path ends - // early or takes an unexpected detour. However, unlike - // insertion, there's a whole post-deletion cleanup phase later - // on. - // - // As we walk down the tree, byteIdx is the byte of bs we're - // currently examining to choose our next step, and numBits is the - // number of bits that remain in pfx, starting with the byte at - // byteIdx inclusive. - bs := pfx.Addr().AsSlice() - byteIdx := 0 - numBits := pfx.Bits() - for numBits > 8 { - if debugDelete { - fmt.Printf("delete: loop byteIdx=%d numBits=%d st.prefix=%s\n", byteIdx, numBits, st.prefix) - } - child := st.getChild(bs[byteIdx]) - if child == nil { - // Prefix can't exist in the table, because one of the - // necessary strideTables doesn't exist. - if debugDelete { - fmt.Printf("delete: missing necessary child pfx=%s\n", pfx) - } - return - } - strideIndexes[strideIdx] = bs[byteIdx] - strideTables[strideIdx+1] = child - strideIdx++ - - // Path compression means byteIdx can jump forwards - // unpredictably. Recompute the next byte to look at from the - // child we just found. - byteIdx = child.prefix.Bits() / 8 - numBits = pfx.Bits() - child.prefix.Bits() - st = child - - if debugDelete { - fmt.Printf("delete: descend st.prefix=%s\n", st.prefix) - } - } - - // We reached a leaf stride table that seems to be in the right - // spot. But path compression might have led us to the wrong - // table. - if !prefixStrictlyContains(st.prefix, pfx) { - // Wrong table, the requested prefix can't exist since its - // path led us to the wrong place. - if debugDelete { - fmt.Printf("delete: wrong leaf table pfx=%s\n", pfx) - } - return - } - if debugDelete { - fmt.Printf("delete: delete from st.prefix=%s addr=%d/%d\n", st.prefix, bs[byteIdx], numBits) - } - if routeExisted := st.delete(bs[byteIdx], numBits); !routeExisted { - // We're in the right strideTable, but pfx wasn't in - // it. Refcounts haven't changed, so we can skip cleanup. - if debugDelete { - fmt.Printf("delete: prefix not present pfx=%s\n", pfx) - } - return - } - - // st.delete reduced st's refcount by one. This table may now be - // reclaimable, and depending on how we can reclaim it, the parent - // tables may also need to be reclaimed. This loop ends as soon as - // an iteration takes no action, or takes an action that doesn't - // alter the parent table's refcounts. - // - // We start our walk back at strideTables[strideIdx], which - // contains st. - for strideIdx > 0 { - cur := strideTables[strideIdx] - if debugDelete { - fmt.Printf("delete: GC? strideIdx=%d st.prefix=%s\n", strideIdx, cur.prefix) - } - if cur.routeRefs > 0 { - // the strideTable has other route entries, it cannot be - // deleted or compacted. - if debugDelete { - fmt.Printf("delete: has other routes st.prefix=%s\n", cur.prefix) - } - return - } - switch cur.childRefs { - case 0: - // no routeRefs and no childRefs, this table can be - // deleted. This will alter the parent table's refcount, - // so we'll have to look at it as well (in the next loop - // iteration). - if debugDelete { - fmt.Printf("delete: remove st.prefix=%s\n", cur.prefix) - } - strideTables[strideIdx-1].deleteChild(strideIndexes[strideIdx-1]) - strideIdx-- - case 1: - // This table has no routes, and a single child. Compact - // this table out of existence by making the parent point - // directly at the one child. This does not affect the - // parent's refcounts, so the parent can't be eligible for - // deletion or compaction, and we can stop. - child := strideTables[strideIdx].findFirstChild() // only 1 child exists, by definition - parent := strideTables[strideIdx-1] - if debugDelete { - fmt.Printf("delete: compact parent.prefix=%s st.prefix=%s child.prefix=%s\n", parent.prefix, cur.prefix, child.prefix) - } - strideTables[strideIdx-1].setChild(strideIndexes[strideIdx-1], child) - return - default: - // This table has two or more children, so it's acting as a "fork in - // the road" between two prefix subtrees. It cannot be deleted, and - // thus no further cleanups are possible. - if debugDelete { - fmt.Printf("delete: fork table st.prefix=%s\n", cur.prefix) - } - return - } - } -} - -// debugSummary prints the tree of allocated strideTables in t, with each -// strideTable's refcount. -func (t *Table[T]) debugSummary() string { - t.init() - var ret bytes.Buffer - fmt.Fprintf(&ret, "v4: ") - strideSummary(&ret, &t.v4, 4) - fmt.Fprintf(&ret, "v6: ") - strideSummary(&ret, &t.v6, 4) - return ret.String() -} - -func strideSummary[T any](w io.Writer, st *strideTable[T], indent int) { - fmt.Fprintf(w, "%s: %d routes, %d children\n", st.prefix, st.routeRefs, st.childRefs) - indent += 4 - st.treeDebugStringRec(w, 1, indent) - for addr, child := range st.children { - if child == nil { - continue - } - fmt.Fprintf(w, "%s%d/8 (%02x/8): ", strings.Repeat(" ", indent), addr, addr) - strideSummary(w, child, indent) - } -} - -// prefixStrictlyContains reports whether child is a prefix within -// parent, but not parent itself. -func prefixStrictlyContains(parent, child netip.Prefix) bool { - return parent.Overlaps(child) && parent.Bits() < child.Bits() -} - -// computePrefixSplit returns the smallest common prefix that contains -// both a and b. lastCommon is 8-bit aligned, with aStride and bStride -// indicating the value of the 8-bit stride immediately following -// lastCommon. -// -// computePrefixSplit is used in constructing an intermediate -// strideTable when a new prefix needs to be inserted in a compressed -// table. It can be read as: given that a is already in the table, and -// b is being inserted, what is the prefix of the new intermediate -// strideTable that needs to be created, and at what addresses in that -// new strideTable should a and b's subsequent strideTables be -// attached? -// -// Note as a special case, this can be called with a==b. An example of -// when this happens: -// - We want to insert the prefix 1.2.0.0/16 -// - A strideTable exists for 1.2.0.0/16, because another child -// prefix already exists (e.g. 1.2.3.4/32) -// - The 1.0.0.0/8 strideTable does not exist, because path -// compression removed it. -// -// In this scenario, the caller of computePrefixSplit ends up making a -// "wrong turn" while traversing strideTables: it was looking for the -// 1.0.0.0/8 table, but ended up at the 1.2.0.0/16 table. When this -// happens, it will invoke computePrefixSplit(1.2.0.0/16, 1.2.0.0/16), -// and we return 1.0.0.0/8 as the missing intermediate. -func computePrefixSplit(a, b netip.Prefix) (lastCommon netip.Prefix, aStride, bStride uint8) { - a = a.Masked() - b = b.Masked() - if a.Bits() == 0 || b.Bits() == 0 { - panic("computePrefixSplit called with a default route") - } - if a.Addr().Is4() != b.Addr().Is4() { - panic("computePrefixSplit called with mismatched address families") - } - - minPrefixLen := a.Bits() - if b.Bits() < minPrefixLen { - minPrefixLen = b.Bits() - } - - commonBits := commonBits(a.Addr(), b.Addr(), minPrefixLen) - // We want to know how many 8-bit strides are shared between a and - // b. Naively, this would be commonBits/8, but this introduces an - // off-by-one error. This is due to the way our ART stores - // prefixes whose length falls exactly on a stride boundary. - // - // Consider 192.168.1.0/24 and 192.168.0.0/16. commonBits - // correctly reports that these prefixes have their first 16 bits - // in common. However, in the ART they only share 1 common stride: - // they both use the 192.0.0.0/8 strideTable, but 192.168.0.0/16 - // is stored as 168/8 within that table, and not as 0/0 in the - // 192.168.0.0/16 table. - // - // So, when commonBits matches the length of one of the inputs and - // falls on a boundary between strides, the strideTable one - // further up from commonBits/8 is the one we need to create, - // which means we have to adjust the stride count down by one. - if commonBits == minPrefixLen { - commonBits-- - } - commonStrides := commonBits / 8 - lastCommon, err := a.Addr().Prefix(commonStrides * 8) - if err != nil { - panic(fmt.Sprintf("computePrefixSplit constructing common prefix: %v", err)) - } - if a.Addr().Is4() { - aStride = a.Addr().As4()[commonStrides] - bStride = b.Addr().As4()[commonStrides] - } else { - aStride = a.Addr().As16()[commonStrides] - bStride = b.Addr().As16()[commonStrides] - } - return lastCommon, aStride, bStride -} - -// commonBits returns the number of common leading bits of a and b. -// If the number of common bits exceeds maxBits, it returns maxBits -// instead. -func commonBits(a, b netip.Addr, maxBits int) int { - if a.Is4() != b.Is4() { - panic("commonStrides called with mismatched address families") - } - var common int - // The following implements an old bit-twiddling trick to compute - // the number of common leading bits: if you XOR two numbers - // together, equal bits become 0 and unequal bits become 1. You - // can then count the number of leading zeros (which is a single - // instruction on modern CPUs) to get the answer. - // - // This code is a little more complex than just XOR + count - // leading zeros, because IPv4 and IPv6 are different sizes, and - // for IPv6 we have to do the math in two 64-bit chunks because Go - // lacks a uint128 type. - if a.Is4() { - aNum, bNum := ipv4AsUint(a), ipv4AsUint(b) - common = bits.LeadingZeros32(aNum ^ bNum) - } else { - aNumHi, aNumLo := ipv6AsUint(a) - bNumHi, bNumLo := ipv6AsUint(b) - common = bits.LeadingZeros64(aNumHi ^ bNumHi) - if common == 64 { - common += bits.LeadingZeros64(aNumLo ^ bNumLo) - } - } - if common > maxBits { - common = maxBits - } - return common -} - -// ipv4AsUint returns ip as a uint32. -func ipv4AsUint(ip netip.Addr) uint32 { - bs := ip.As4() - return binary.BigEndian.Uint32(bs[:]) -} - -// ipv6AsUint returns ip as a pair of uint64s. -func ipv6AsUint(ip netip.Addr) (uint64, uint64) { - bs := ip.As16() - return binary.BigEndian.Uint64(bs[:8]), binary.BigEndian.Uint64(bs[8:]) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package art provides a routing table that implements the Allotment Routing +// Table (ART) algorithm by Donald Knuth, as described in the paper by Yoichi +// Hariguchi. +// +// ART outperforms the traditional radix tree implementations for route lookups, +// insertions, and deletions. +// +// For more information, see Yoichi Hariguchi's paper: +// https://cseweb.ucsd.edu//~varghese/TEACH/cs228/artlookup.pdf +package art + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "math/bits" + "net/netip" + "strings" + "sync" +) + +const ( + debugInsert = false + debugDelete = false +) + +// Table is an IPv4 and IPv6 routing table. +type Table[T any] struct { + v4 strideTable[T] + v6 strideTable[T] + initOnce sync.Once +} + +func (t *Table[T]) init() { + t.initOnce.Do(func() { + t.v4.prefix = netip.PrefixFrom(netip.IPv4Unspecified(), 0) + t.v6.prefix = netip.PrefixFrom(netip.IPv6Unspecified(), 0) + }) +} + +func (t *Table[T]) tableForAddr(addr netip.Addr) *strideTable[T] { + if addr.Is6() { + return &t.v6 + } + return &t.v4 +} + +// Get does a route lookup for addr and returns the associated value, or nil if +// no route matched. +func (t *Table[T]) Get(addr netip.Addr) (ret T, ok bool) { + t.init() + + // Ideally we would use addr.AsSlice here, but AsSlice is just + // barely complex enough that it can't be inlined, and that in + // turn causes the slice to escape to the heap. Using As16 and + // manual slicing here helps the compiler keep Get alloc-free. + st := t.tableForAddr(addr) + rawAddr := addr.As16() + bs := rawAddr[:] + if addr.Is4() { + bs = bs[12:] + } + + i := 0 + // With path compression, we might skip over some address bits while walking + // to a strideTable leaf. This means the leaf answer we find might not be + // correct, because path compression took us down the wrong subtree. When + // that happens, we have to backtrack and figure out which most specific + // route further up the tree is relevant to addr, and return that. + // + // So, as we walk down the stride tables, each time we find a non-nil route + // result, we have to remember it and the associated strideTable prefix. + // + // We could also deal with this edge case of path compression by checking + // the strideTable prefix on each table as we descend, but that means we + // have to pay N prefix.Contains checks on every route lookup (where N is + // the number of strideTables in the path), rather than only paying M prefix + // comparisons in the edge case (where M is the number of strideTables in + // the path with a non-nil route of their own). + const maxDepth = 16 + type prefixAndRoute struct { + prefix netip.Prefix + route T + } + strideMatch := make([]prefixAndRoute, 0, maxDepth) +findLeaf: + for { + rt, rtOK, child := st.getValAndChild(bs[i]) + if rtOK { + // This strideTable contains a route that may be relevant to our + // search, remember it. + strideMatch = append(strideMatch, prefixAndRoute{st.prefix, rt}) + } + if child == nil { + // No sub-routes further down, the last thing we recorded + // in strideRoutes is tentatively the result, barring + // misdirection from path compression. + break findLeaf + } + st = child + // Path compression means we may be skipping over some intermediate + // tables. We have to skip forward to whatever depth st now references. + i = st.prefix.Bits() / 8 + } + + // Walk backwards through the hits we recorded in strideRoutes and + // stridePrefixes, returning the first one whose subtree matches addr. + // + // In the common case where path compression did not mislead us, we'll + // return on the first loop iteration because the last route we recorded was + // the correct most-specific route. + for i := len(strideMatch) - 1; i >= 0; i-- { + if m := strideMatch[i]; m.prefix.Contains(addr) { + return m.route, true + } + } + + // We either found no route hits at all (both previous loops terminated + // immediately), or we went on a wild goose chase down a compressed path for + // the wrong prefix, and also found no usable routes on the way back up to + // the root. This is a miss. + return ret, false +} + +// Insert adds pfx to the table, with value val. +// If pfx is already present in the table, its value is set to val. +func (t *Table[T]) Insert(pfx netip.Prefix, val T) { + t.init() + + // The standard library doesn't enforce normalized prefixes (where + // the non-prefix bits are all zero). These algorithms require + // normalized prefixes, so do it upfront. + pfx = pfx.Masked() + + if debugInsert { + defer func() { + fmt.Printf("%s", t.debugSummary()) + }() + fmt.Printf("\ninsert: start pfx=%s\n", pfx) + } + + st := t.tableForAddr(pfx.Addr()) + + // This algorithm is full of off-by-one headaches that boil down + // to the fact that pfx.Bits() has (2^n)+1 values, rather than + // just 2^n. For example, an IPv4 prefix length can be 0 through + // 32, which is 33 values. + // + // This extra possible value creates a lot of problems as we do + // bits and bytes math to traverse strideTables below. So, we + // treat the default route 0/0 specially here, that way the rest + // of the logic goes back to having 2^n values to reason about, + // which can be done in a nice and regular fashion with no edge + // cases. + if pfx.Bits() == 0 { + if debugInsert { + fmt.Printf("insert: default route\n") + } + st.insert(0, 0, val) + return + } + + // No matter what we do as we traverse strideTables, our final + // action will be to insert the last 1-8 bits of pfx into a + // strideTable somewhere. + // + // We calculate upfront the byte position of the end of the + // prefix; the number of bits within that byte that contain prefix + // data; and the prefix of the strideTable into which we'll + // eventually insert. + // + // We need this in a couple different branches of the code below, + // and because the possible values are 1-indexed (1 through 32 for + // ipv4, 1 through 128 for ipv6), the math is very slightly + // unusual to account for the off-by-one indexing. Do it once up + // here, with this large comment, rather than reproduce the subtle + // math in multiple places further down. + finalByteIdx := (pfx.Bits() - 1) / 8 + finalBits := pfx.Bits() - (finalByteIdx * 8) + finalStridePrefix, err := pfx.Addr().Prefix(finalByteIdx * 8) + if err != nil { + panic(fmt.Sprintf("invalid prefix requested: %s/%d", pfx.Addr(), finalByteIdx*8)) + } + if debugInsert { + fmt.Printf("insert: finalByteIdx=%d finalBits=%d finalStridePrefix=%s\n", finalByteIdx, finalBits, finalStridePrefix) + } + + // The strideTable we want to insert into is potentially at the + // end of a chain of strideTables, each one encoding 8 bits of the + // prefix. + // + // We're expecting to walk down a path of tables, although with + // prefix compression we may end up skipping some links in the + // chain, or taking wrong turns and having to course correct. + // + // As we walk down the tree, byteIdx is the byte of bs we're + // currently examining to choose our next step, and numBits is the + // number of bits that remain in pfx, starting with the byte at + // byteIdx inclusive. + bs := pfx.Addr().AsSlice() + byteIdx := 0 + numBits := pfx.Bits() + for { + if debugInsert { + fmt.Printf("insert: loop byteIdx=%d numBits=%d st.prefix=%s\n", byteIdx, numBits, st.prefix) + } + if numBits <= 8 { + if debugInsert { + fmt.Printf("insert: existing leaf st.prefix=%s addr=%d/%d\n", st.prefix, bs[finalByteIdx], finalBits) + } + // We've reached the end of the prefix, whichever + // strideTable we're looking at now is the place where we + // need to insert. + st.insert(bs[finalByteIdx], finalBits, val) + return + } + + // Otherwise, we need to go down at least one more level of + // strideTables. With prefix compression, each level of + // descent can have one of three outcomes: we find a place + // where prefix compression is possible; a place where prefix + // compression made us take a "wrong turn"; or a point along + // our intended path that we have to keep following. + child, created := st.getOrCreateChild(bs[byteIdx]) + switch { + case created: + // The subtree we need for pfx doesn't exist yet. The rest + // of the path, if we were to create it, will consist of a + // bunch of strideTables with a single child each. We can + // use path compression to elide those intermediates, and + // jump straight to the final strideTable that hosts this + // prefix. + child.prefix = finalStridePrefix + child.insert(bs[finalByteIdx], finalBits, val) + if debugInsert { + fmt.Printf("insert: new leaf st.prefix=%s child.prefix=%s addr=%d/%d\n", st.prefix, child.prefix, bs[finalByteIdx], finalBits) + } + return + case !prefixStrictlyContains(child.prefix, pfx): + // child already exists, but its prefix does not contain + // our destination. This means that the path between st + // and child was compressed by a previous insertion, and + // somewhere in the (implicit) compressed path we took a + // wrong turn, into the wrong part of st's subtree. + // + // This is okay, because pfx and child.prefix must have a + // common ancestor node somewhere between st and child. We + // can figure out what node that is, and materialize it. + // + // Once we've done that, we can immediately complete the + // remainder of the insertion in one of two ways, without + // further traversal. See a little further down for what + // those are. + if debugInsert { + fmt.Printf("insert: wrong turn, pfx=%s child.prefix=%s\n", pfx, child.prefix) + } + intermediatePrefix, addrOfExisting, addrOfNew := computePrefixSplit(child.prefix, pfx) + intermediate := &strideTable[T]{prefix: intermediatePrefix} // TODO: make this whole thing be st.AddIntermediate or something? + st.setChild(bs[byteIdx], intermediate) + intermediate.setChild(addrOfExisting, child) + + if debugInsert { + fmt.Printf("insert: new intermediate st.prefix=%s intermediate.prefix=%s child.prefix=%s\n", st.prefix, intermediate.prefix, child.prefix) + } + + // Now, we have a chain of st -> intermediate -> child. + // + // pfx either lives in a different child of intermediate, + // or in intermediate itself. For example, if we created + // the intermediate 1.2.0.0/16, pfx=1.2.3.4/32 would have + // to go into a new child of intermediate, but + // pfx=1.2.0.0/18 would go into intermediate directly. + if remain := pfx.Bits() - intermediate.prefix.Bits(); remain <= 8 { + // pfx lives in intermediate. + if debugInsert { + fmt.Printf("insert: into intermediate intermediate.prefix=%s addr=%d/%d\n", intermediate.prefix, bs[finalByteIdx], finalBits) + } + intermediate.insert(bs[finalByteIdx], finalBits, val) + } else { + // pfx lives in a different child subtree of + // intermediate. By definition this subtree doesn't + // exist at all, otherwise we'd never have entered + // this entire "wrong turn" codepath in the first + // place. + // + // This means we can apply prefix compression as we + // create this new child, and we're done. + st, created = intermediate.getOrCreateChild(addrOfNew) + if !created { + panic("new child path unexpectedly exists during path decompression") + } + st.prefix = finalStridePrefix + st.insert(bs[finalByteIdx], finalBits, val) + if debugInsert { + fmt.Printf("insert: new child st.prefix=%s addr=%d/%d\n", st.prefix, bs[finalByteIdx], finalBits) + } + } + + return + default: + // An expected child table exists along pfx's + // path. Continue traversing downwards. + st = child + byteIdx = child.prefix.Bits() / 8 + numBits = pfx.Bits() - child.prefix.Bits() + if debugInsert { + fmt.Printf("insert: descend st.prefix=%s\n", st.prefix) + } + } + } +} + +// Delete removes pfx from the table, if it is present. +func (t *Table[T]) Delete(pfx netip.Prefix) { + t.init() + + // The standard library doesn't enforce normalized prefixes (where + // the non-prefix bits are all zero). These algorithms require + // normalized prefixes, so do it upfront. + pfx = pfx.Masked() + + if debugDelete { + defer func() { + fmt.Printf("%s", t.debugSummary()) + }() + fmt.Printf("\ndelete: start pfx=%s table:\n%s", pfx, t.debugSummary()) + } + + st := t.tableForAddr(pfx.Addr()) + + // This algorithm is full of off-by-one headaches, just like + // Insert. See the comment in Insert for more details. Bottom + // line: we handle the default route as a special case, and that + // simplifies the rest of the code slightly. + if pfx.Bits() == 0 { + if debugDelete { + fmt.Printf("delete: default route\n") + } + st.delete(0, 0) + return + } + + // Deletion may drive the refcount of some strideTables down to + // zero. We need to clean up these dangling tables, so we have to + // keep track of which tables we touch on the way down, and which + // strideEntry index each child is registered in. + // + // Note that the strideIndex and strideTables entries are off-by-one. + // The child table pointer is recorded at i+1, but it is referenced by a + // particular index in the parent table, at index i. + // + // In other words: entry number strideIndexes[0] in + // strideTables[0] is the same pointer as strideTables[1]. + // + // This results in some slightly odd array accesses further down + // in this code, because in a single loop iteration we have to + // write to strideTables[N] and strideIndexes[N-1]. + strideIdx := 0 + strideTables := [16]*strideTable[T]{st} + strideIndexes := [15]uint8{} + + // Similar to Insert, navigate down the tree of strideTables, + // looking for the one that houses this prefix. This part is + // easier than with insertion, since we can bail if the path ends + // early or takes an unexpected detour. However, unlike + // insertion, there's a whole post-deletion cleanup phase later + // on. + // + // As we walk down the tree, byteIdx is the byte of bs we're + // currently examining to choose our next step, and numBits is the + // number of bits that remain in pfx, starting with the byte at + // byteIdx inclusive. + bs := pfx.Addr().AsSlice() + byteIdx := 0 + numBits := pfx.Bits() + for numBits > 8 { + if debugDelete { + fmt.Printf("delete: loop byteIdx=%d numBits=%d st.prefix=%s\n", byteIdx, numBits, st.prefix) + } + child := st.getChild(bs[byteIdx]) + if child == nil { + // Prefix can't exist in the table, because one of the + // necessary strideTables doesn't exist. + if debugDelete { + fmt.Printf("delete: missing necessary child pfx=%s\n", pfx) + } + return + } + strideIndexes[strideIdx] = bs[byteIdx] + strideTables[strideIdx+1] = child + strideIdx++ + + // Path compression means byteIdx can jump forwards + // unpredictably. Recompute the next byte to look at from the + // child we just found. + byteIdx = child.prefix.Bits() / 8 + numBits = pfx.Bits() - child.prefix.Bits() + st = child + + if debugDelete { + fmt.Printf("delete: descend st.prefix=%s\n", st.prefix) + } + } + + // We reached a leaf stride table that seems to be in the right + // spot. But path compression might have led us to the wrong + // table. + if !prefixStrictlyContains(st.prefix, pfx) { + // Wrong table, the requested prefix can't exist since its + // path led us to the wrong place. + if debugDelete { + fmt.Printf("delete: wrong leaf table pfx=%s\n", pfx) + } + return + } + if debugDelete { + fmt.Printf("delete: delete from st.prefix=%s addr=%d/%d\n", st.prefix, bs[byteIdx], numBits) + } + if routeExisted := st.delete(bs[byteIdx], numBits); !routeExisted { + // We're in the right strideTable, but pfx wasn't in + // it. Refcounts haven't changed, so we can skip cleanup. + if debugDelete { + fmt.Printf("delete: prefix not present pfx=%s\n", pfx) + } + return + } + + // st.delete reduced st's refcount by one. This table may now be + // reclaimable, and depending on how we can reclaim it, the parent + // tables may also need to be reclaimed. This loop ends as soon as + // an iteration takes no action, or takes an action that doesn't + // alter the parent table's refcounts. + // + // We start our walk back at strideTables[strideIdx], which + // contains st. + for strideIdx > 0 { + cur := strideTables[strideIdx] + if debugDelete { + fmt.Printf("delete: GC? strideIdx=%d st.prefix=%s\n", strideIdx, cur.prefix) + } + if cur.routeRefs > 0 { + // the strideTable has other route entries, it cannot be + // deleted or compacted. + if debugDelete { + fmt.Printf("delete: has other routes st.prefix=%s\n", cur.prefix) + } + return + } + switch cur.childRefs { + case 0: + // no routeRefs and no childRefs, this table can be + // deleted. This will alter the parent table's refcount, + // so we'll have to look at it as well (in the next loop + // iteration). + if debugDelete { + fmt.Printf("delete: remove st.prefix=%s\n", cur.prefix) + } + strideTables[strideIdx-1].deleteChild(strideIndexes[strideIdx-1]) + strideIdx-- + case 1: + // This table has no routes, and a single child. Compact + // this table out of existence by making the parent point + // directly at the one child. This does not affect the + // parent's refcounts, so the parent can't be eligible for + // deletion or compaction, and we can stop. + child := strideTables[strideIdx].findFirstChild() // only 1 child exists, by definition + parent := strideTables[strideIdx-1] + if debugDelete { + fmt.Printf("delete: compact parent.prefix=%s st.prefix=%s child.prefix=%s\n", parent.prefix, cur.prefix, child.prefix) + } + strideTables[strideIdx-1].setChild(strideIndexes[strideIdx-1], child) + return + default: + // This table has two or more children, so it's acting as a "fork in + // the road" between two prefix subtrees. It cannot be deleted, and + // thus no further cleanups are possible. + if debugDelete { + fmt.Printf("delete: fork table st.prefix=%s\n", cur.prefix) + } + return + } + } +} + +// debugSummary prints the tree of allocated strideTables in t, with each +// strideTable's refcount. +func (t *Table[T]) debugSummary() string { + t.init() + var ret bytes.Buffer + fmt.Fprintf(&ret, "v4: ") + strideSummary(&ret, &t.v4, 4) + fmt.Fprintf(&ret, "v6: ") + strideSummary(&ret, &t.v6, 4) + return ret.String() +} + +func strideSummary[T any](w io.Writer, st *strideTable[T], indent int) { + fmt.Fprintf(w, "%s: %d routes, %d children\n", st.prefix, st.routeRefs, st.childRefs) + indent += 4 + st.treeDebugStringRec(w, 1, indent) + for addr, child := range st.children { + if child == nil { + continue + } + fmt.Fprintf(w, "%s%d/8 (%02x/8): ", strings.Repeat(" ", indent), addr, addr) + strideSummary(w, child, indent) + } +} + +// prefixStrictlyContains reports whether child is a prefix within +// parent, but not parent itself. +func prefixStrictlyContains(parent, child netip.Prefix) bool { + return parent.Overlaps(child) && parent.Bits() < child.Bits() +} + +// computePrefixSplit returns the smallest common prefix that contains +// both a and b. lastCommon is 8-bit aligned, with aStride and bStride +// indicating the value of the 8-bit stride immediately following +// lastCommon. +// +// computePrefixSplit is used in constructing an intermediate +// strideTable when a new prefix needs to be inserted in a compressed +// table. It can be read as: given that a is already in the table, and +// b is being inserted, what is the prefix of the new intermediate +// strideTable that needs to be created, and at what addresses in that +// new strideTable should a and b's subsequent strideTables be +// attached? +// +// Note as a special case, this can be called with a==b. An example of +// when this happens: +// - We want to insert the prefix 1.2.0.0/16 +// - A strideTable exists for 1.2.0.0/16, because another child +// prefix already exists (e.g. 1.2.3.4/32) +// - The 1.0.0.0/8 strideTable does not exist, because path +// compression removed it. +// +// In this scenario, the caller of computePrefixSplit ends up making a +// "wrong turn" while traversing strideTables: it was looking for the +// 1.0.0.0/8 table, but ended up at the 1.2.0.0/16 table. When this +// happens, it will invoke computePrefixSplit(1.2.0.0/16, 1.2.0.0/16), +// and we return 1.0.0.0/8 as the missing intermediate. +func computePrefixSplit(a, b netip.Prefix) (lastCommon netip.Prefix, aStride, bStride uint8) { + a = a.Masked() + b = b.Masked() + if a.Bits() == 0 || b.Bits() == 0 { + panic("computePrefixSplit called with a default route") + } + if a.Addr().Is4() != b.Addr().Is4() { + panic("computePrefixSplit called with mismatched address families") + } + + minPrefixLen := a.Bits() + if b.Bits() < minPrefixLen { + minPrefixLen = b.Bits() + } + + commonBits := commonBits(a.Addr(), b.Addr(), minPrefixLen) + // We want to know how many 8-bit strides are shared between a and + // b. Naively, this would be commonBits/8, but this introduces an + // off-by-one error. This is due to the way our ART stores + // prefixes whose length falls exactly on a stride boundary. + // + // Consider 192.168.1.0/24 and 192.168.0.0/16. commonBits + // correctly reports that these prefixes have their first 16 bits + // in common. However, in the ART they only share 1 common stride: + // they both use the 192.0.0.0/8 strideTable, but 192.168.0.0/16 + // is stored as 168/8 within that table, and not as 0/0 in the + // 192.168.0.0/16 table. + // + // So, when commonBits matches the length of one of the inputs and + // falls on a boundary between strides, the strideTable one + // further up from commonBits/8 is the one we need to create, + // which means we have to adjust the stride count down by one. + if commonBits == minPrefixLen { + commonBits-- + } + commonStrides := commonBits / 8 + lastCommon, err := a.Addr().Prefix(commonStrides * 8) + if err != nil { + panic(fmt.Sprintf("computePrefixSplit constructing common prefix: %v", err)) + } + if a.Addr().Is4() { + aStride = a.Addr().As4()[commonStrides] + bStride = b.Addr().As4()[commonStrides] + } else { + aStride = a.Addr().As16()[commonStrides] + bStride = b.Addr().As16()[commonStrides] + } + return lastCommon, aStride, bStride +} + +// commonBits returns the number of common leading bits of a and b. +// If the number of common bits exceeds maxBits, it returns maxBits +// instead. +func commonBits(a, b netip.Addr, maxBits int) int { + if a.Is4() != b.Is4() { + panic("commonStrides called with mismatched address families") + } + var common int + // The following implements an old bit-twiddling trick to compute + // the number of common leading bits: if you XOR two numbers + // together, equal bits become 0 and unequal bits become 1. You + // can then count the number of leading zeros (which is a single + // instruction on modern CPUs) to get the answer. + // + // This code is a little more complex than just XOR + count + // leading zeros, because IPv4 and IPv6 are different sizes, and + // for IPv6 we have to do the math in two 64-bit chunks because Go + // lacks a uint128 type. + if a.Is4() { + aNum, bNum := ipv4AsUint(a), ipv4AsUint(b) + common = bits.LeadingZeros32(aNum ^ bNum) + } else { + aNumHi, aNumLo := ipv6AsUint(a) + bNumHi, bNumLo := ipv6AsUint(b) + common = bits.LeadingZeros64(aNumHi ^ bNumHi) + if common == 64 { + common += bits.LeadingZeros64(aNumLo ^ bNumLo) + } + } + if common > maxBits { + common = maxBits + } + return common +} + +// ipv4AsUint returns ip as a uint32. +func ipv4AsUint(ip netip.Addr) uint32 { + bs := ip.As4() + return binary.BigEndian.Uint32(bs[:]) +} + +// ipv6AsUint returns ip as a pair of uint64s. +func ipv6AsUint(ip netip.Addr) (uint64, uint64) { + bs := ip.As16() + return binary.BigEndian.Uint64(bs[:8]), binary.BigEndian.Uint64(bs[8:]) +} diff --git a/net/dns/debian_resolvconf.go b/net/dns/debian_resolvconf.go index 2a1fb18de..3ffc796e0 100644 --- a/net/dns/debian_resolvconf.go +++ b/net/dns/debian_resolvconf.go @@ -1,184 +1,184 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux || freebsd || openbsd - -package dns - -import ( - "bufio" - "bytes" - _ "embed" - "fmt" - "os" - "os/exec" - "path/filepath" - - "tailscale.com/atomicfile" - "tailscale.com/types/logger" -) - -//go:embed resolvconf-workaround.sh -var workaroundScript []byte - -// resolvconfConfigName is the name of the config submitted to -// resolvconf. -// The name starts with 'tun' in order to match the hardcoded -// interface order in debian resolvconf, which will place this -// configuration ahead of regular network links. In theory, this -// doesn't matter because we then fix things up to ensure our config -// is the only one in use, but in case that fails, this will make our -// configuration slightly preferred. -// The 'inet' suffix has no specific meaning, but conventionally -// resolvconf implementations encourage adding a suffix roughly -// indicating where the config came from, and "inet" is the "none of -// the above" value (rather than, say, "ppp" or "dhcp"). -const resolvconfConfigName = "tun-tailscale.inet" - -// resolvconfLibcHookPath is the directory containing libc update -// scripts, which are run by Debian resolvconf when /etc/resolv.conf -// has been updated. -const resolvconfLibcHookPath = "/etc/resolvconf/update-libc.d" - -// resolvconfHookPath is the name of the libc hook script we install -// to force Tailscale's DNS config to take effect. -var resolvconfHookPath = filepath.Join(resolvconfLibcHookPath, "tailscale") - -// resolvconfManager manages DNS configuration using the Debian -// implementation of the `resolvconf` program, written by Thomas Hood. -type resolvconfManager struct { - logf logger.Logf - listRecordsPath string - interfacesDir string - scriptInstalled bool // libc update script has been installed -} - -func newDebianResolvconfManager(logf logger.Logf) (*resolvconfManager, error) { - ret := &resolvconfManager{ - logf: logf, - listRecordsPath: "/lib/resolvconf/list-records", - interfacesDir: "/etc/resolvconf/run/interface", // panic fallback if nothing seems to work - } - - if _, err := os.Stat(ret.listRecordsPath); os.IsNotExist(err) { - // This might be a Debian system from before the big /usr - // merge, try /usr instead. - ret.listRecordsPath = "/usr" + ret.listRecordsPath - } - // The runtime directory is currently (2020-04) canonically - // /etc/resolvconf/run, but the manpage is making noise about - // switching to /run/resolvconf and dropping the /etc path. So, - // let's probe the possible directories and use the first one - // that works. - for _, path := range []string{ - "/etc/resolvconf/run/interface", - "/run/resolvconf/interface", - "/var/run/resolvconf/interface", - } { - if _, err := os.Stat(path); err == nil { - ret.interfacesDir = path - break - } - } - if ret.interfacesDir == "" { - // None of the paths seem to work, use the canonical location - // that the current manpage says to use. - ret.interfacesDir = "/etc/resolvconf/run/interfaces" - } - - return ret, nil -} - -func (m *resolvconfManager) deleteTailscaleConfig() error { - cmd := exec.Command("resolvconf", "-d", resolvconfConfigName) - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("running %s: %s", cmd, out) - } - return nil -} - -func (m *resolvconfManager) SetDNS(config OSConfig) error { - if !m.scriptInstalled { - m.logf("injecting resolvconf workaround script") - if err := os.MkdirAll(resolvconfLibcHookPath, 0755); err != nil { - return err - } - if err := atomicfile.WriteFile(resolvconfHookPath, workaroundScript, 0755); err != nil { - return err - } - m.scriptInstalled = true - } - - if config.IsZero() { - if err := m.deleteTailscaleConfig(); err != nil { - return err - } - } else { - stdin := new(bytes.Buffer) - writeResolvConf(stdin, config.Nameservers, config.SearchDomains) // dns_direct.go - - // This resolvconf implementation doesn't support exclusive - // mode or interface priorities, so it will end up blending - // our configuration with other sources. However, this will - // get fixed up by the script we injected above. - cmd := exec.Command("resolvconf", "-a", resolvconfConfigName) - cmd.Stdin = stdin - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("running %s: %s", cmd, out) - } - } - - return nil -} - -func (m *resolvconfManager) SupportsSplitDNS() bool { - return false -} - -func (m *resolvconfManager) GetBaseConfig() (OSConfig, error) { - var bs bytes.Buffer - - cmd := exec.Command(m.listRecordsPath) - // list-records assumes it's being run with CWD set to the - // interfaces runtime dir, and returns nonsense otherwise. - cmd.Dir = m.interfacesDir - cmd.Stdout = &bs - if err := cmd.Run(); err != nil { - return OSConfig{}, err - } - - var conf bytes.Buffer - sc := bufio.NewScanner(&bs) - for sc.Scan() { - if sc.Text() == resolvconfConfigName { - continue - } - bs, err := os.ReadFile(filepath.Join(m.interfacesDir, sc.Text())) - if err != nil { - if os.IsNotExist(err) { - // Probably raced with a deletion, that's okay. - continue - } - return OSConfig{}, err - } - conf.Write(bs) - conf.WriteByte('\n') - } - - return readResolv(&conf) -} - -func (m *resolvconfManager) Close() error { - if err := m.deleteTailscaleConfig(); err != nil { - return err - } - - if m.scriptInstalled { - m.logf("removing resolvconf workaround script") - os.Remove(resolvconfHookPath) // Best-effort - } - - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux || freebsd || openbsd + +package dns + +import ( + "bufio" + "bytes" + _ "embed" + "fmt" + "os" + "os/exec" + "path/filepath" + + "tailscale.com/atomicfile" + "tailscale.com/types/logger" +) + +//go:embed resolvconf-workaround.sh +var workaroundScript []byte + +// resolvconfConfigName is the name of the config submitted to +// resolvconf. +// The name starts with 'tun' in order to match the hardcoded +// interface order in debian resolvconf, which will place this +// configuration ahead of regular network links. In theory, this +// doesn't matter because we then fix things up to ensure our config +// is the only one in use, but in case that fails, this will make our +// configuration slightly preferred. +// The 'inet' suffix has no specific meaning, but conventionally +// resolvconf implementations encourage adding a suffix roughly +// indicating where the config came from, and "inet" is the "none of +// the above" value (rather than, say, "ppp" or "dhcp"). +const resolvconfConfigName = "tun-tailscale.inet" + +// resolvconfLibcHookPath is the directory containing libc update +// scripts, which are run by Debian resolvconf when /etc/resolv.conf +// has been updated. +const resolvconfLibcHookPath = "/etc/resolvconf/update-libc.d" + +// resolvconfHookPath is the name of the libc hook script we install +// to force Tailscale's DNS config to take effect. +var resolvconfHookPath = filepath.Join(resolvconfLibcHookPath, "tailscale") + +// resolvconfManager manages DNS configuration using the Debian +// implementation of the `resolvconf` program, written by Thomas Hood. +type resolvconfManager struct { + logf logger.Logf + listRecordsPath string + interfacesDir string + scriptInstalled bool // libc update script has been installed +} + +func newDebianResolvconfManager(logf logger.Logf) (*resolvconfManager, error) { + ret := &resolvconfManager{ + logf: logf, + listRecordsPath: "/lib/resolvconf/list-records", + interfacesDir: "/etc/resolvconf/run/interface", // panic fallback if nothing seems to work + } + + if _, err := os.Stat(ret.listRecordsPath); os.IsNotExist(err) { + // This might be a Debian system from before the big /usr + // merge, try /usr instead. + ret.listRecordsPath = "/usr" + ret.listRecordsPath + } + // The runtime directory is currently (2020-04) canonically + // /etc/resolvconf/run, but the manpage is making noise about + // switching to /run/resolvconf and dropping the /etc path. So, + // let's probe the possible directories and use the first one + // that works. + for _, path := range []string{ + "/etc/resolvconf/run/interface", + "/run/resolvconf/interface", + "/var/run/resolvconf/interface", + } { + if _, err := os.Stat(path); err == nil { + ret.interfacesDir = path + break + } + } + if ret.interfacesDir == "" { + // None of the paths seem to work, use the canonical location + // that the current manpage says to use. + ret.interfacesDir = "/etc/resolvconf/run/interfaces" + } + + return ret, nil +} + +func (m *resolvconfManager) deleteTailscaleConfig() error { + cmd := exec.Command("resolvconf", "-d", resolvconfConfigName) + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("running %s: %s", cmd, out) + } + return nil +} + +func (m *resolvconfManager) SetDNS(config OSConfig) error { + if !m.scriptInstalled { + m.logf("injecting resolvconf workaround script") + if err := os.MkdirAll(resolvconfLibcHookPath, 0755); err != nil { + return err + } + if err := atomicfile.WriteFile(resolvconfHookPath, workaroundScript, 0755); err != nil { + return err + } + m.scriptInstalled = true + } + + if config.IsZero() { + if err := m.deleteTailscaleConfig(); err != nil { + return err + } + } else { + stdin := new(bytes.Buffer) + writeResolvConf(stdin, config.Nameservers, config.SearchDomains) // dns_direct.go + + // This resolvconf implementation doesn't support exclusive + // mode or interface priorities, so it will end up blending + // our configuration with other sources. However, this will + // get fixed up by the script we injected above. + cmd := exec.Command("resolvconf", "-a", resolvconfConfigName) + cmd.Stdin = stdin + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("running %s: %s", cmd, out) + } + } + + return nil +} + +func (m *resolvconfManager) SupportsSplitDNS() bool { + return false +} + +func (m *resolvconfManager) GetBaseConfig() (OSConfig, error) { + var bs bytes.Buffer + + cmd := exec.Command(m.listRecordsPath) + // list-records assumes it's being run with CWD set to the + // interfaces runtime dir, and returns nonsense otherwise. + cmd.Dir = m.interfacesDir + cmd.Stdout = &bs + if err := cmd.Run(); err != nil { + return OSConfig{}, err + } + + var conf bytes.Buffer + sc := bufio.NewScanner(&bs) + for sc.Scan() { + if sc.Text() == resolvconfConfigName { + continue + } + bs, err := os.ReadFile(filepath.Join(m.interfacesDir, sc.Text())) + if err != nil { + if os.IsNotExist(err) { + // Probably raced with a deletion, that's okay. + continue + } + return OSConfig{}, err + } + conf.Write(bs) + conf.WriteByte('\n') + } + + return readResolv(&conf) +} + +func (m *resolvconfManager) Close() error { + if err := m.deleteTailscaleConfig(); err != nil { + return err + } + + if m.scriptInstalled { + m.logf("removing resolvconf workaround script") + os.Remove(resolvconfHookPath) // Best-effort + } + + return nil +} diff --git a/net/dns/direct_notlinux.go b/net/dns/direct_notlinux.go index 5bd8093d6..c221ca1be 100644 --- a/net/dns/direct_notlinux.go +++ b/net/dns/direct_notlinux.go @@ -1,10 +1,10 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux - -package dns - -func (m *directManager) runFileWatcher() { - // Not implemented on other platforms. Maybe it could resort to polling. -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux + +package dns + +func (m *directManager) runFileWatcher() { + // Not implemented on other platforms. Maybe it could resort to polling. +} diff --git a/net/dns/flush_default.go b/net/dns/flush_default.go index 73e446389..eb6d9da41 100644 --- a/net/dns/flush_default.go +++ b/net/dns/flush_default.go @@ -1,10 +1,10 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows - -package dns - -func flushCaches() error { - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows + +package dns + +func flushCaches() error { + return nil +} diff --git a/net/dns/ini.go b/net/dns/ini.go index deec04019..1e47d606e 100644 --- a/net/dns/ini.go +++ b/net/dns/ini.go @@ -1,30 +1,30 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build windows - -package dns - -import ( - "regexp" - "strings" -) - -// parseIni parses a basic .ini file, used for wsl.conf. -func parseIni(data string) map[string]map[string]string { - sectionRE := regexp.MustCompile(`^\[([^]]+)\]`) - kvRE := regexp.MustCompile(`^\s*(\w+)\s*=\s*([^#]*)`) - - ini := map[string]map[string]string{} - var section string - for _, line := range strings.Split(data, "\n") { - if res := sectionRE.FindStringSubmatch(line); len(res) > 1 { - section = res[1] - ini[section] = map[string]string{} - } else if res := kvRE.FindStringSubmatch(line); len(res) > 2 { - k, v := strings.TrimSpace(res[1]), strings.TrimSpace(res[2]) - ini[section][k] = v - } - } - return ini -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows + +package dns + +import ( + "regexp" + "strings" +) + +// parseIni parses a basic .ini file, used for wsl.conf. +func parseIni(data string) map[string]map[string]string { + sectionRE := regexp.MustCompile(`^\[([^]]+)\]`) + kvRE := regexp.MustCompile(`^\s*(\w+)\s*=\s*([^#]*)`) + + ini := map[string]map[string]string{} + var section string + for _, line := range strings.Split(data, "\n") { + if res := sectionRE.FindStringSubmatch(line); len(res) > 1 { + section = res[1] + ini[section] = map[string]string{} + } else if res := kvRE.FindStringSubmatch(line); len(res) > 2 { + k, v := strings.TrimSpace(res[1]), strings.TrimSpace(res[2]) + ini[section][k] = v + } + } + return ini +} diff --git a/net/dns/ini_test.go b/net/dns/ini_test.go index 0e9eaa672..3afe7009c 100644 --- a/net/dns/ini_test.go +++ b/net/dns/ini_test.go @@ -1,38 +1,38 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build windows - -package dns - -import ( - "reflect" - "testing" -) - -func TestParseIni(t *testing.T) { - var tests = []struct { - src string - want map[string]map[string]string - }{ - { - src: `# appended wsl.conf file -[automount] - enabled = true - root=/mnt/ -# added by tailscale -[network] # trailing comment -generateResolvConf = false # trailing comment`, - want: map[string]map[string]string{ - "automount": {"enabled": "true", "root": "/mnt/"}, - "network": {"generateResolvConf": "false"}, - }, - }, - } - for _, test := range tests { - got := parseIni(test.src) - if !reflect.DeepEqual(got, test.want) { - t.Errorf("for:\n%s\ngot: %v\nwant: %v", test.src, got, test.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows + +package dns + +import ( + "reflect" + "testing" +) + +func TestParseIni(t *testing.T) { + var tests = []struct { + src string + want map[string]map[string]string + }{ + { + src: `# appended wsl.conf file +[automount] + enabled = true + root=/mnt/ +# added by tailscale +[network] # trailing comment +generateResolvConf = false # trailing comment`, + want: map[string]map[string]string{ + "automount": {"enabled": "true", "root": "/mnt/"}, + "network": {"generateResolvConf": "false"}, + }, + }, + } + for _, test := range tests { + got := parseIni(test.src) + if !reflect.DeepEqual(got, test.want) { + t.Errorf("for:\n%s\ngot: %v\nwant: %v", test.src, got, test.want) + } + } +} diff --git a/net/dns/noop.go b/net/dns/noop.go index c90162668..9466b57a0 100644 --- a/net/dns/noop.go +++ b/net/dns/noop.go @@ -1,17 +1,17 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package dns - -type noopManager struct{} - -func (m noopManager) SetDNS(OSConfig) error { return nil } -func (m noopManager) SupportsSplitDNS() bool { return false } -func (m noopManager) Close() error { return nil } -func (m noopManager) GetBaseConfig() (OSConfig, error) { - return OSConfig{}, ErrGetBaseConfigNotSupported -} - -func NewNoopManager() (noopManager, error) { - return noopManager{}, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package dns + +type noopManager struct{} + +func (m noopManager) SetDNS(OSConfig) error { return nil } +func (m noopManager) SupportsSplitDNS() bool { return false } +func (m noopManager) Close() error { return nil } +func (m noopManager) GetBaseConfig() (OSConfig, error) { + return OSConfig{}, ErrGetBaseConfigNotSupported +} + +func NewNoopManager() (noopManager, error) { + return noopManager{}, nil +} diff --git a/net/dns/resolvconf-workaround.sh b/net/dns/resolvconf-workaround.sh index 254b3949b..aec6708a0 100644 --- a/net/dns/resolvconf-workaround.sh +++ b/net/dns/resolvconf-workaround.sh @@ -1,62 +1,62 @@ -#!/bin/sh -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause -# -# This script is a workaround for a vpn-unfriendly behavior of the -# original resolvconf by Thomas Hood. Unlike the `openresolv` -# implementation (whose binary is also called resolvconf, -# confusingly), the original resolvconf lacks a way to specify -# "exclusive mode" for a provider configuration. In practice, this -# means that if Tailscale wants to install a DNS configuration, that -# config will get "blended" with the configs from other sources, -# rather than override those other sources. -# -# This script gets installed at /etc/resolvconf/update-libc.d, which -# is a directory of hook scripts that get run after resolvconf's libc -# helper has finished rewriting /etc/resolv.conf. It's meant to notify -# consumers of resolv.conf of a new configuration. -# -# Instead, we use that hook mechanism to reach into resolvconf's -# stuff, and rewrite the libc-generated resolv.conf to exclusively -# contain Tailscale's configuration - effectively implementing -# exclusive mode ourselves in post-production. - -set -e - -if [ -n "$TAILSCALE_RESOLVCONF_HOOK_LOOP" ]; then - # Hook script being invoked by itself, skip. - exit 0 -fi - -if [ ! -f tun-tailscale.inet ]; then - # Tailscale isn't trying to manage DNS, do nothing. - exit 0 -fi - -if ! grep resolvconf /etc/resolv.conf >/dev/null; then - # resolvconf isn't managing /etc/resolv.conf, do nothing. - exit 0 -fi - -# Write out a modified /etc/resolv.conf containing just our config. -( - if [ -f /etc/resolvconf/resolv.conf.d/head ]; then - cat /etc/resolvconf/resolv.conf.d/head - fi - echo "# Tailscale workaround applied to set exclusive DNS configuration." - cat tun-tailscale.inet - if [ -f /etc/resolvconf/resolv.conf.d/base ]; then - # Keep options and sortlist, discard other base things since - # they're the things we're trying to override. - grep -e 'sortlist ' -e 'options ' /etc/resolvconf/resolv.conf.d/base || true - fi - if [ -f /etc/resolvconf/resolv.conf.d/tail ]; then - cat /etc/resolvconf/resolv.conf.d/tail - fi -) >/etc/resolv.conf - -if [ -d /etc/resolvconf/update-libc.d ] ; then - # Re-notify libc watchers that we've changed resolv.conf again. - export TAILSCALE_RESOLVCONF_HOOK_LOOP=1 - exec run-parts /etc/resolvconf/update-libc.d -fi +#!/bin/sh +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause +# +# This script is a workaround for a vpn-unfriendly behavior of the +# original resolvconf by Thomas Hood. Unlike the `openresolv` +# implementation (whose binary is also called resolvconf, +# confusingly), the original resolvconf lacks a way to specify +# "exclusive mode" for a provider configuration. In practice, this +# means that if Tailscale wants to install a DNS configuration, that +# config will get "blended" with the configs from other sources, +# rather than override those other sources. +# +# This script gets installed at /etc/resolvconf/update-libc.d, which +# is a directory of hook scripts that get run after resolvconf's libc +# helper has finished rewriting /etc/resolv.conf. It's meant to notify +# consumers of resolv.conf of a new configuration. +# +# Instead, we use that hook mechanism to reach into resolvconf's +# stuff, and rewrite the libc-generated resolv.conf to exclusively +# contain Tailscale's configuration - effectively implementing +# exclusive mode ourselves in post-production. + +set -e + +if [ -n "$TAILSCALE_RESOLVCONF_HOOK_LOOP" ]; then + # Hook script being invoked by itself, skip. + exit 0 +fi + +if [ ! -f tun-tailscale.inet ]; then + # Tailscale isn't trying to manage DNS, do nothing. + exit 0 +fi + +if ! grep resolvconf /etc/resolv.conf >/dev/null; then + # resolvconf isn't managing /etc/resolv.conf, do nothing. + exit 0 +fi + +# Write out a modified /etc/resolv.conf containing just our config. +( + if [ -f /etc/resolvconf/resolv.conf.d/head ]; then + cat /etc/resolvconf/resolv.conf.d/head + fi + echo "# Tailscale workaround applied to set exclusive DNS configuration." + cat tun-tailscale.inet + if [ -f /etc/resolvconf/resolv.conf.d/base ]; then + # Keep options and sortlist, discard other base things since + # they're the things we're trying to override. + grep -e 'sortlist ' -e 'options ' /etc/resolvconf/resolv.conf.d/base || true + fi + if [ -f /etc/resolvconf/resolv.conf.d/tail ]; then + cat /etc/resolvconf/resolv.conf.d/tail + fi +) >/etc/resolv.conf + +if [ -d /etc/resolvconf/update-libc.d ] ; then + # Re-notify libc watchers that we've changed resolv.conf again. + export TAILSCALE_RESOLVCONF_HOOK_LOOP=1 + exec run-parts /etc/resolvconf/update-libc.d +fi diff --git a/net/dns/resolvconf.go b/net/dns/resolvconf.go index 9e2a41c4a..ca584ffcc 100644 --- a/net/dns/resolvconf.go +++ b/net/dns/resolvconf.go @@ -1,30 +1,30 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux || freebsd || openbsd - -package dns - -import ( - "bytes" - "os/exec" -) - -func resolvconfStyle() string { - if _, err := exec.LookPath("resolvconf"); err != nil { - return "" - } - output, err := exec.Command("resolvconf", "--version").CombinedOutput() - if err != nil { - // Debian resolvconf doesn't understand --version, and - // exits with a specific error code. - if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 99 { - return "debian" - } - } - if bytes.HasPrefix(output, []byte("Debian resolvconf")) { - return "debian" - } - // Treat everything else as openresolv, by far the more popular implementation. - return "openresolv" -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux || freebsd || openbsd + +package dns + +import ( + "bytes" + "os/exec" +) + +func resolvconfStyle() string { + if _, err := exec.LookPath("resolvconf"); err != nil { + return "" + } + output, err := exec.Command("resolvconf", "--version").CombinedOutput() + if err != nil { + // Debian resolvconf doesn't understand --version, and + // exits with a specific error code. + if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 99 { + return "debian" + } + } + if bytes.HasPrefix(output, []byte("Debian resolvconf")) { + return "debian" + } + // Treat everything else as openresolv, by far the more popular implementation. + return "openresolv" +} diff --git a/net/dns/resolvconffile/resolvconffile.go b/net/dns/resolvconffile/resolvconffile.go index 66c1600d8..753000f6d 100644 --- a/net/dns/resolvconffile/resolvconffile.go +++ b/net/dns/resolvconffile/resolvconffile.go @@ -1,124 +1,124 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package resolvconffile parses & serializes /etc/resolv.conf-style files. -// -// It's a leaf package so both net/dns and net/dns/resolver can depend -// on it and we can unify a handful of implementations. -// -// The package is verbosely named to disambiguate it from resolvconf -// the daemon, which Tailscale also supports. -package resolvconffile - -import ( - "bufio" - "bytes" - "fmt" - "io" - "net/netip" - "os" - "strings" - - "tailscale.com/util/dnsname" -) - -// Path is the canonical location of resolv.conf. -const Path = "/etc/resolv.conf" - -// Config represents a resolv.conf(5) file. -type Config struct { - // Nameservers are the IP addresses of the nameservers to use. - Nameservers []netip.Addr - - // SearchDomains are the domain suffixes to use when expanding - // single-label name queries. SearchDomains is additive to - // whatever non-Tailscale search domains the OS has. - SearchDomains []dnsname.FQDN -} - -// Write writes c to w. It does so in one Write call. -func (c *Config) Write(w io.Writer) error { - buf := new(bytes.Buffer) - io.WriteString(buf, "# resolv.conf(5) file generated by tailscale\n") - io.WriteString(buf, "# For more info, see https://tailscale.com/s/resolvconf-overwrite\n") - io.WriteString(buf, "# DO NOT EDIT THIS FILE BY HAND -- CHANGES WILL BE OVERWRITTEN\n\n") - for _, ns := range c.Nameservers { - io.WriteString(buf, "nameserver ") - io.WriteString(buf, ns.String()) - io.WriteString(buf, "\n") - } - if len(c.SearchDomains) > 0 { - io.WriteString(buf, "search") - for _, domain := range c.SearchDomains { - io.WriteString(buf, " ") - io.WriteString(buf, domain.WithoutTrailingDot()) - } - io.WriteString(buf, "\n") - } - _, err := w.Write(buf.Bytes()) - return err -} - -// Parse parses a resolv.conf file from r. -func Parse(r io.Reader) (*Config, error) { - config := new(Config) - scanner := bufio.NewScanner(r) - for scanner.Scan() { - line := scanner.Text() - line, _, _ = strings.Cut(line, "#") // remove any comments - line = strings.TrimSpace(line) - - if s, ok := strings.CutPrefix(line, "nameserver"); ok { - nameserver := strings.TrimSpace(s) - if len(nameserver) == len(s) { - return nil, fmt.Errorf("missing space after \"nameserver\" in %q", line) - } - ip, err := netip.ParseAddr(nameserver) - if err != nil { - return nil, err - } - config.Nameservers = append(config.Nameservers, ip) - continue - } - - if s, ok := strings.CutPrefix(line, "search"); ok { - domains := strings.TrimSpace(s) - if len(domains) == len(s) { - // No leading space?! - return nil, fmt.Errorf("missing space after \"search\" in %q", line) - } - for len(domains) > 0 { - domain := domains - i := strings.IndexAny(domain, " \t") - if i != -1 { - domain = domain[:i] - domains = strings.TrimSpace(domains[i+1:]) - } else { - domains = "" - } - fqdn, err := dnsname.ToFQDN(domain) - if err != nil { - return nil, fmt.Errorf("parsing search domain %q in %q: %w", domain, line, err) - } - config.SearchDomains = append(config.SearchDomains, fqdn) - } - } - } - return config, nil -} - -// ParseFile parses the named resolv.conf file. -func ParseFile(name string) (*Config, error) { - fi, err := os.Stat(name) - if err != nil { - return nil, err - } - if n := fi.Size(); n > 10<<10 { - return nil, fmt.Errorf("unexpectedly large %q file: %d bytes", name, n) - } - all, err := os.ReadFile(name) - if err != nil { - return nil, err - } - return Parse(bytes.NewReader(all)) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package resolvconffile parses & serializes /etc/resolv.conf-style files. +// +// It's a leaf package so both net/dns and net/dns/resolver can depend +// on it and we can unify a handful of implementations. +// +// The package is verbosely named to disambiguate it from resolvconf +// the daemon, which Tailscale also supports. +package resolvconffile + +import ( + "bufio" + "bytes" + "fmt" + "io" + "net/netip" + "os" + "strings" + + "tailscale.com/util/dnsname" +) + +// Path is the canonical location of resolv.conf. +const Path = "/etc/resolv.conf" + +// Config represents a resolv.conf(5) file. +type Config struct { + // Nameservers are the IP addresses of the nameservers to use. + Nameservers []netip.Addr + + // SearchDomains are the domain suffixes to use when expanding + // single-label name queries. SearchDomains is additive to + // whatever non-Tailscale search domains the OS has. + SearchDomains []dnsname.FQDN +} + +// Write writes c to w. It does so in one Write call. +func (c *Config) Write(w io.Writer) error { + buf := new(bytes.Buffer) + io.WriteString(buf, "# resolv.conf(5) file generated by tailscale\n") + io.WriteString(buf, "# For more info, see https://tailscale.com/s/resolvconf-overwrite\n") + io.WriteString(buf, "# DO NOT EDIT THIS FILE BY HAND -- CHANGES WILL BE OVERWRITTEN\n\n") + for _, ns := range c.Nameservers { + io.WriteString(buf, "nameserver ") + io.WriteString(buf, ns.String()) + io.WriteString(buf, "\n") + } + if len(c.SearchDomains) > 0 { + io.WriteString(buf, "search") + for _, domain := range c.SearchDomains { + io.WriteString(buf, " ") + io.WriteString(buf, domain.WithoutTrailingDot()) + } + io.WriteString(buf, "\n") + } + _, err := w.Write(buf.Bytes()) + return err +} + +// Parse parses a resolv.conf file from r. +func Parse(r io.Reader) (*Config, error) { + config := new(Config) + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := scanner.Text() + line, _, _ = strings.Cut(line, "#") // remove any comments + line = strings.TrimSpace(line) + + if s, ok := strings.CutPrefix(line, "nameserver"); ok { + nameserver := strings.TrimSpace(s) + if len(nameserver) == len(s) { + return nil, fmt.Errorf("missing space after \"nameserver\" in %q", line) + } + ip, err := netip.ParseAddr(nameserver) + if err != nil { + return nil, err + } + config.Nameservers = append(config.Nameservers, ip) + continue + } + + if s, ok := strings.CutPrefix(line, "search"); ok { + domains := strings.TrimSpace(s) + if len(domains) == len(s) { + // No leading space?! + return nil, fmt.Errorf("missing space after \"search\" in %q", line) + } + for len(domains) > 0 { + domain := domains + i := strings.IndexAny(domain, " \t") + if i != -1 { + domain = domain[:i] + domains = strings.TrimSpace(domains[i+1:]) + } else { + domains = "" + } + fqdn, err := dnsname.ToFQDN(domain) + if err != nil { + return nil, fmt.Errorf("parsing search domain %q in %q: %w", domain, line, err) + } + config.SearchDomains = append(config.SearchDomains, fqdn) + } + } + } + return config, nil +} + +// ParseFile parses the named resolv.conf file. +func ParseFile(name string) (*Config, error) { + fi, err := os.Stat(name) + if err != nil { + return nil, err + } + if n := fi.Size(); n > 10<<10 { + return nil, fmt.Errorf("unexpectedly large %q file: %d bytes", name, n) + } + all, err := os.ReadFile(name) + if err != nil { + return nil, err + } + return Parse(bytes.NewReader(all)) +} diff --git a/net/dns/resolvconfpath_default.go b/net/dns/resolvconfpath_default.go index 02f24a0cf..57e82c4c7 100644 --- a/net/dns/resolvconfpath_default.go +++ b/net/dns/resolvconfpath_default.go @@ -1,11 +1,11 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !gokrazy - -package dns - -const ( - resolvConf = "/etc/resolv.conf" - backupConf = "/etc/resolv.pre-tailscale-backup.conf" -) +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !gokrazy + +package dns + +const ( + resolvConf = "/etc/resolv.conf" + backupConf = "/etc/resolv.pre-tailscale-backup.conf" +) diff --git a/net/dns/resolvconfpath_gokrazy.go b/net/dns/resolvconfpath_gokrazy.go index 6315596d2..f0759b0e3 100644 --- a/net/dns/resolvconfpath_gokrazy.go +++ b/net/dns/resolvconfpath_gokrazy.go @@ -1,11 +1,11 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build gokrazy - -package dns - -const ( - resolvConf = "/tmp/resolv.conf" - backupConf = "/tmp/resolv.pre-tailscale-backup.conf" -) +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build gokrazy + +package dns + +const ( + resolvConf = "/tmp/resolv.conf" + backupConf = "/tmp/resolv.pre-tailscale-backup.conf" +) diff --git a/net/dns/resolver/doh_test.go b/net/dns/resolver/doh_test.go index d9ef970c2..a9c284761 100644 --- a/net/dns/resolver/doh_test.go +++ b/net/dns/resolver/doh_test.go @@ -1,99 +1,99 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package resolver - -import ( - "context" - "flag" - "net/http" - "testing" - - "golang.org/x/net/dns/dnsmessage" - "tailscale.com/net/dns/publicdns" -) - -var testDoH = flag.Bool("test-doh", false, "do real DoH tests against the network") - -const someDNSID = 123 // something non-zero as a test; in violation of spec's SHOULD of 0 - -func someDNSQuestion(t testing.TB) []byte { - b := dnsmessage.NewBuilder(nil, dnsmessage.Header{ - OpCode: 0, // query - RecursionDesired: true, - ID: someDNSID, - }) - b.StartQuestions() // err - b.Question(dnsmessage.Question{ - Name: dnsmessage.MustNewName("tailscale.com."), - Type: dnsmessage.TypeA, - Class: dnsmessage.ClassINET, - }) - msg, err := b.Finish() - if err != nil { - t.Fatal(err) - } - return msg -} - -func TestDoH(t *testing.T) { - if !*testDoH { - t.Skip("skipping manual test without --test-doh flag") - } - prefixes := publicdns.KnownDoHPrefixes() - if len(prefixes) == 0 { - t.Fatal("no known DoH") - } - - f := &forwarder{} - - for _, urlBase := range prefixes { - t.Run(urlBase, func(t *testing.T) { - c, ok := f.getKnownDoHClientForProvider(urlBase) - if !ok { - t.Fatal("expected DoH") - } - res, err := f.sendDoH(context.Background(), urlBase, c, someDNSQuestion(t)) - if err != nil { - t.Fatal(err) - } - c.Transport.(*http.Transport).CloseIdleConnections() - - var p dnsmessage.Parser - h, err := p.Start(res) - if err != nil { - t.Fatal(err) - } - if h.ID != someDNSID { - t.Errorf("response DNS ID = %v; want %v", h.ID, someDNSID) - } - - p.SkipAllQuestions() - aa, err := p.AllAnswers() - if err != nil { - t.Fatal(err) - } - if len(aa) == 0 { - t.Fatal("no answers") - } - for _, r := range aa { - t.Logf("got: %v", r.GoString()) - } - }) - } -} - -func TestDoHV6Fallback(t *testing.T) { - for _, base := range publicdns.KnownDoHPrefixes() { - for _, ip := range publicdns.DoHIPsOfBase(base) { - if ip.Is4() { - ip6, ok := publicdns.DoHV6(base) - if !ok { - t.Errorf("no v6 DoH known for %v", ip) - } else if !ip6.Is6() { - t.Errorf("dohV6(%q) returned non-v6 address %v", base, ip6) - } - } - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package resolver + +import ( + "context" + "flag" + "net/http" + "testing" + + "golang.org/x/net/dns/dnsmessage" + "tailscale.com/net/dns/publicdns" +) + +var testDoH = flag.Bool("test-doh", false, "do real DoH tests against the network") + +const someDNSID = 123 // something non-zero as a test; in violation of spec's SHOULD of 0 + +func someDNSQuestion(t testing.TB) []byte { + b := dnsmessage.NewBuilder(nil, dnsmessage.Header{ + OpCode: 0, // query + RecursionDesired: true, + ID: someDNSID, + }) + b.StartQuestions() // err + b.Question(dnsmessage.Question{ + Name: dnsmessage.MustNewName("tailscale.com."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }) + msg, err := b.Finish() + if err != nil { + t.Fatal(err) + } + return msg +} + +func TestDoH(t *testing.T) { + if !*testDoH { + t.Skip("skipping manual test without --test-doh flag") + } + prefixes := publicdns.KnownDoHPrefixes() + if len(prefixes) == 0 { + t.Fatal("no known DoH") + } + + f := &forwarder{} + + for _, urlBase := range prefixes { + t.Run(urlBase, func(t *testing.T) { + c, ok := f.getKnownDoHClientForProvider(urlBase) + if !ok { + t.Fatal("expected DoH") + } + res, err := f.sendDoH(context.Background(), urlBase, c, someDNSQuestion(t)) + if err != nil { + t.Fatal(err) + } + c.Transport.(*http.Transport).CloseIdleConnections() + + var p dnsmessage.Parser + h, err := p.Start(res) + if err != nil { + t.Fatal(err) + } + if h.ID != someDNSID { + t.Errorf("response DNS ID = %v; want %v", h.ID, someDNSID) + } + + p.SkipAllQuestions() + aa, err := p.AllAnswers() + if err != nil { + t.Fatal(err) + } + if len(aa) == 0 { + t.Fatal("no answers") + } + for _, r := range aa { + t.Logf("got: %v", r.GoString()) + } + }) + } +} + +func TestDoHV6Fallback(t *testing.T) { + for _, base := range publicdns.KnownDoHPrefixes() { + for _, ip := range publicdns.DoHIPsOfBase(base) { + if ip.Is4() { + ip6, ok := publicdns.DoHV6(base) + if !ok { + t.Errorf("no v6 DoH known for %v", ip) + } else if !ip6.Is6() { + t.Errorf("dohV6(%q) returned non-v6 address %v", base, ip6) + } + } + } + } +} diff --git a/net/dns/resolver/macios_ext.go b/net/dns/resolver/macios_ext.go index 37cccc7f0..e3f979c19 100644 --- a/net/dns/resolver/macios_ext.go +++ b/net/dns/resolver/macios_ext.go @@ -1,26 +1,26 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ts_macext && (darwin || ios) - -package resolver - -import ( - "errors" - "net" - - "tailscale.com/net/netmon" - "tailscale.com/net/netns" -) - -func init() { - initListenConfig = initListenConfigNetworkExtension -} - -func initListenConfigNetworkExtension(nc *net.ListenConfig, netMon *netmon.Monitor, tunName string) error { - nif, ok := netMon.InterfaceState().Interface[tunName] - if !ok { - return errors.New("utun not found") - } - return netns.SetListenConfigInterfaceIndex(nc, nif.Interface.Index) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ts_macext && (darwin || ios) + +package resolver + +import ( + "errors" + "net" + + "tailscale.com/net/netmon" + "tailscale.com/net/netns" +) + +func init() { + initListenConfig = initListenConfigNetworkExtension +} + +func initListenConfigNetworkExtension(nc *net.ListenConfig, netMon *netmon.Monitor, tunName string) error { + nif, ok := netMon.InterfaceState().Interface[tunName] + if !ok { + return errors.New("utun not found") + } + return netns.SetListenConfigInterfaceIndex(nc, nif.Interface.Index) +} diff --git a/net/dns/resolver/tsdns_server_test.go b/net/dns/resolver/tsdns_server_test.go index be47cdfbc..82fd3bebf 100644 --- a/net/dns/resolver/tsdns_server_test.go +++ b/net/dns/resolver/tsdns_server_test.go @@ -1,333 +1,333 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package resolver - -import ( - "fmt" - "net" - "net/netip" - "strings" - "testing" - - "github.com/miekg/dns" -) - -// This file exists to isolate the test infrastructure -// that depends on github.com/miekg/dns -// from the rest, which only depends on dnsmessage. - -// resolveToIP returns a handler function which responds -// to queries of type A it receives with an A record containing ipv4, -// to queries of type AAAA with an AAAA record containing ipv6, -// to queries of type NS with an NS record containing name. -func resolveToIP(ipv4, ipv6 netip.Addr, ns string) dns.HandlerFunc { - return func(w dns.ResponseWriter, req *dns.Msg) { - m := new(dns.Msg) - m.SetReply(req) - - if len(req.Question) != 1 { - panic("not a single-question request") - } - question := req.Question[0] - - var ans dns.RR - switch question.Qtype { - case dns.TypeA: - ans = &dns.A{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - }, - A: ipv4.AsSlice(), - } - case dns.TypeAAAA: - ans = &dns.AAAA{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - }, - AAAA: ipv6.AsSlice(), - } - case dns.TypeNS: - ans = &dns.NS{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeNS, - Class: dns.ClassINET, - }, - Ns: ns, - } - } - - m.Answer = append(m.Answer, ans) - w.WriteMsg(m) - } -} - -// resolveToIPLowercase returns a handler function which canonicalizes responses -// by lowercasing the question and answer names, and responds -// to queries of type A it receives with an A record containing ipv4, -// to queries of type AAAA with an AAAA record containing ipv6, -// to queries of type NS with an NS record containing name. -func resolveToIPLowercase(ipv4, ipv6 netip.Addr, ns string) dns.HandlerFunc { - return func(w dns.ResponseWriter, req *dns.Msg) { - m := new(dns.Msg) - m.SetReply(req) - - if len(req.Question) != 1 { - panic("not a single-question request") - } - m.Question[0].Name = strings.ToLower(m.Question[0].Name) - question := req.Question[0] - - var ans dns.RR - switch question.Qtype { - case dns.TypeA: - ans = &dns.A{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - }, - A: ipv4.AsSlice(), - } - case dns.TypeAAAA: - ans = &dns.AAAA{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - }, - AAAA: ipv6.AsSlice(), - } - case dns.TypeNS: - ans = &dns.NS{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeNS, - Class: dns.ClassINET, - }, - Ns: ns, - } - } - - m.Answer = append(m.Answer, ans) - w.WriteMsg(m) - } -} - -// resolveToTXT returns a handler function which responds to queries of type TXT -// it receives with the strings in txts. -func resolveToTXT(txts []string, ednsMaxSize uint16) dns.HandlerFunc { - return func(w dns.ResponseWriter, req *dns.Msg) { - m := new(dns.Msg) - m.SetReply(req) - - if len(req.Question) != 1 { - panic("not a single-question request") - } - question := req.Question[0] - - if question.Qtype != dns.TypeTXT { - w.WriteMsg(m) - return - } - - ans := &dns.TXT{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeTXT, - Class: dns.ClassINET, - }, - Txt: txts, - } - - m.Answer = append(m.Answer, ans) - - queryInfo := &dns.TXT{ - Hdr: dns.RR_Header{ - Name: "query-info.test.", - Rrtype: dns.TypeTXT, - Class: dns.ClassINET, - }, - } - - if edns := req.IsEdns0(); edns == nil { - queryInfo.Txt = []string{"EDNS=false"} - } else { - queryInfo.Txt = []string{"EDNS=true", fmt.Sprintf("maxSize=%v", edns.UDPSize())} - } - - m.Extra = append(m.Extra, queryInfo) - - if ednsMaxSize > 0 { - m.SetEdns0(ednsMaxSize, false) - } - - if err := w.WriteMsg(m); err != nil { - panic(err) - } - } -} - -var resolveToNXDOMAIN = dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { - m := new(dns.Msg) - m.SetRcode(req, dns.RcodeNameError) - w.WriteMsg(m) -}) - -// weirdoGoCNAMEHandler returns a DNS handler that satisfies -// Go's weird Resolver.LookupCNAME (read its godoc carefully!). -// -// This doesn't even return a CNAME record, because that's not -// what Go looks for. -func weirdoGoCNAMEHandler(target string) dns.HandlerFunc { - return func(w dns.ResponseWriter, req *dns.Msg) { - m := new(dns.Msg) - m.SetReply(req) - question := req.Question[0] - - switch question.Qtype { - case dns.TypeA: - m.Answer = append(m.Answer, &dns.CNAME{ - Hdr: dns.RR_Header{ - Name: target, - Rrtype: dns.TypeCNAME, - Class: dns.ClassINET, - Ttl: 600, - }, - Target: target, - }) - case dns.TypeAAAA: - m.Answer = append(m.Answer, &dns.AAAA{ - Hdr: dns.RR_Header{ - Name: target, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - Ttl: 600, - }, - AAAA: net.ParseIP("1::2"), - }) - } - w.WriteMsg(m) - } -} - -// dnsHandler returns a handler that replies with the answers/options -// provided. -// -// Types supported: netip.Addr. -func dnsHandler(answers ...any) dns.HandlerFunc { - return func(w dns.ResponseWriter, req *dns.Msg) { - m := new(dns.Msg) - m.SetReply(req) - if len(req.Question) != 1 { - panic("not a single-question request") - } - m.RecursionAvailable = true // to stop net package's errLameReferral on empty replies - - question := req.Question[0] - for _, a := range answers { - switch a := a.(type) { - default: - panic(fmt.Sprintf("unsupported dnsHandler arg %T", a)) - case netip.Addr: - ip := a - if ip.Is4() { - m.Answer = append(m.Answer, &dns.A{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - }, - A: ip.AsSlice(), - }) - } else if ip.Is6() { - m.Answer = append(m.Answer, &dns.AAAA{ - Hdr: dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - }, - AAAA: ip.AsSlice(), - }) - } - case dns.PTR: - ptr := a - ptr.Hdr = dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypePTR, - Class: dns.ClassINET, - } - m.Answer = append(m.Answer, &ptr) - case dns.CNAME: - c := a - c.Hdr = dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeCNAME, - Class: dns.ClassINET, - Ttl: 600, - } - m.Answer = append(m.Answer, &c) - case dns.TXT: - txt := a - txt.Hdr = dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeTXT, - Class: dns.ClassINET, - } - m.Answer = append(m.Answer, &txt) - case dns.SRV: - srv := a - srv.Hdr = dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeSRV, - Class: dns.ClassINET, - } - m.Answer = append(m.Answer, &srv) - case dns.NS: - rr := a - rr.Hdr = dns.RR_Header{ - Name: question.Name, - Rrtype: dns.TypeNS, - Class: dns.ClassINET, - } - m.Answer = append(m.Answer, &rr) - } - } - w.WriteMsg(m) - } -} - -func serveDNS(tb testing.TB, addr string, records ...any) *dns.Server { - if len(records)%2 != 0 { - panic("must have an even number of record values") - } - mux := dns.NewServeMux() - for i := 0; i < len(records); i += 2 { - name := records[i].(string) - handler := records[i+1].(dns.Handler) - mux.Handle(name, handler) - } - waitch := make(chan struct{}) - server := &dns.Server{ - Addr: addr, - Net: "udp", - Handler: mux, - NotifyStartedFunc: func() { close(waitch) }, - ReusePort: true, - } - - go func() { - err := server.ListenAndServe() - if err != nil { - panic(fmt.Sprintf("ListenAndServe(%q): %v", addr, err)) - } - }() - - <-waitch - return server -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package resolver + +import ( + "fmt" + "net" + "net/netip" + "strings" + "testing" + + "github.com/miekg/dns" +) + +// This file exists to isolate the test infrastructure +// that depends on github.com/miekg/dns +// from the rest, which only depends on dnsmessage. + +// resolveToIP returns a handler function which responds +// to queries of type A it receives with an A record containing ipv4, +// to queries of type AAAA with an AAAA record containing ipv6, +// to queries of type NS with an NS record containing name. +func resolveToIP(ipv4, ipv6 netip.Addr, ns string) dns.HandlerFunc { + return func(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.SetReply(req) + + if len(req.Question) != 1 { + panic("not a single-question request") + } + question := req.Question[0] + + var ans dns.RR + switch question.Qtype { + case dns.TypeA: + ans = &dns.A{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: ipv4.AsSlice(), + } + case dns.TypeAAAA: + ans = &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + }, + AAAA: ipv6.AsSlice(), + } + case dns.TypeNS: + ans = &dns.NS{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeNS, + Class: dns.ClassINET, + }, + Ns: ns, + } + } + + m.Answer = append(m.Answer, ans) + w.WriteMsg(m) + } +} + +// resolveToIPLowercase returns a handler function which canonicalizes responses +// by lowercasing the question and answer names, and responds +// to queries of type A it receives with an A record containing ipv4, +// to queries of type AAAA with an AAAA record containing ipv6, +// to queries of type NS with an NS record containing name. +func resolveToIPLowercase(ipv4, ipv6 netip.Addr, ns string) dns.HandlerFunc { + return func(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.SetReply(req) + + if len(req.Question) != 1 { + panic("not a single-question request") + } + m.Question[0].Name = strings.ToLower(m.Question[0].Name) + question := req.Question[0] + + var ans dns.RR + switch question.Qtype { + case dns.TypeA: + ans = &dns.A{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: ipv4.AsSlice(), + } + case dns.TypeAAAA: + ans = &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + }, + AAAA: ipv6.AsSlice(), + } + case dns.TypeNS: + ans = &dns.NS{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeNS, + Class: dns.ClassINET, + }, + Ns: ns, + } + } + + m.Answer = append(m.Answer, ans) + w.WriteMsg(m) + } +} + +// resolveToTXT returns a handler function which responds to queries of type TXT +// it receives with the strings in txts. +func resolveToTXT(txts []string, ednsMaxSize uint16) dns.HandlerFunc { + return func(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.SetReply(req) + + if len(req.Question) != 1 { + panic("not a single-question request") + } + question := req.Question[0] + + if question.Qtype != dns.TypeTXT { + w.WriteMsg(m) + return + } + + ans := &dns.TXT{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + }, + Txt: txts, + } + + m.Answer = append(m.Answer, ans) + + queryInfo := &dns.TXT{ + Hdr: dns.RR_Header{ + Name: "query-info.test.", + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + }, + } + + if edns := req.IsEdns0(); edns == nil { + queryInfo.Txt = []string{"EDNS=false"} + } else { + queryInfo.Txt = []string{"EDNS=true", fmt.Sprintf("maxSize=%v", edns.UDPSize())} + } + + m.Extra = append(m.Extra, queryInfo) + + if ednsMaxSize > 0 { + m.SetEdns0(ednsMaxSize, false) + } + + if err := w.WriteMsg(m); err != nil { + panic(err) + } + } +} + +var resolveToNXDOMAIN = dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.SetRcode(req, dns.RcodeNameError) + w.WriteMsg(m) +}) + +// weirdoGoCNAMEHandler returns a DNS handler that satisfies +// Go's weird Resolver.LookupCNAME (read its godoc carefully!). +// +// This doesn't even return a CNAME record, because that's not +// what Go looks for. +func weirdoGoCNAMEHandler(target string) dns.HandlerFunc { + return func(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.SetReply(req) + question := req.Question[0] + + switch question.Qtype { + case dns.TypeA: + m.Answer = append(m.Answer, &dns.CNAME{ + Hdr: dns.RR_Header{ + Name: target, + Rrtype: dns.TypeCNAME, + Class: dns.ClassINET, + Ttl: 600, + }, + Target: target, + }) + case dns.TypeAAAA: + m.Answer = append(m.Answer, &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: target, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: 600, + }, + AAAA: net.ParseIP("1::2"), + }) + } + w.WriteMsg(m) + } +} + +// dnsHandler returns a handler that replies with the answers/options +// provided. +// +// Types supported: netip.Addr. +func dnsHandler(answers ...any) dns.HandlerFunc { + return func(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.SetReply(req) + if len(req.Question) != 1 { + panic("not a single-question request") + } + m.RecursionAvailable = true // to stop net package's errLameReferral on empty replies + + question := req.Question[0] + for _, a := range answers { + switch a := a.(type) { + default: + panic(fmt.Sprintf("unsupported dnsHandler arg %T", a)) + case netip.Addr: + ip := a + if ip.Is4() { + m.Answer = append(m.Answer, &dns.A{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: ip.AsSlice(), + }) + } else if ip.Is6() { + m.Answer = append(m.Answer, &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + }, + AAAA: ip.AsSlice(), + }) + } + case dns.PTR: + ptr := a + ptr.Hdr = dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + } + m.Answer = append(m.Answer, &ptr) + case dns.CNAME: + c := a + c.Hdr = dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeCNAME, + Class: dns.ClassINET, + Ttl: 600, + } + m.Answer = append(m.Answer, &c) + case dns.TXT: + txt := a + txt.Hdr = dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + } + m.Answer = append(m.Answer, &txt) + case dns.SRV: + srv := a + srv.Hdr = dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeSRV, + Class: dns.ClassINET, + } + m.Answer = append(m.Answer, &srv) + case dns.NS: + rr := a + rr.Hdr = dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeNS, + Class: dns.ClassINET, + } + m.Answer = append(m.Answer, &rr) + } + } + w.WriteMsg(m) + } +} + +func serveDNS(tb testing.TB, addr string, records ...any) *dns.Server { + if len(records)%2 != 0 { + panic("must have an even number of record values") + } + mux := dns.NewServeMux() + for i := 0; i < len(records); i += 2 { + name := records[i].(string) + handler := records[i+1].(dns.Handler) + mux.Handle(name, handler) + } + waitch := make(chan struct{}) + server := &dns.Server{ + Addr: addr, + Net: "udp", + Handler: mux, + NotifyStartedFunc: func() { close(waitch) }, + ReusePort: true, + } + + go func() { + err := server.ListenAndServe() + if err != nil { + panic(fmt.Sprintf("ListenAndServe(%q): %v", addr, err)) + } + }() + + <-waitch + return server +} diff --git a/net/dns/utf.go b/net/dns/utf.go index 267829c05..0c1db69ac 100644 --- a/net/dns/utf.go +++ b/net/dns/utf.go @@ -1,55 +1,55 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package dns - -// This code is only used in Windows builds, but is in an -// OS-independent file so tests can run all the time. - -import ( - "bytes" - "encoding/binary" - "unicode/utf16" -) - -// maybeUnUTF16 tries to detect whether bs contains UTF-16, and if so -// translates it to regular UTF-8. -// -// Some of wsl.exe's output get printed as UTF-16, which breaks a -// bunch of things. Try to detect this by looking for a zero byte in -// the first few bytes of output (which will appear if any of those -// codepoints are basic ASCII - very likely). From that we can infer -// that UTF-16 is being printed, and the byte order in use, and we -// decode that back to UTF-8. -// -// https://github.com/microsoft/WSL/issues/4607 -func maybeUnUTF16(bs []byte) []byte { - if len(bs)%2 != 0 { - // Can't be complete UTF-16. - return bs - } - checkLen := 20 - if len(bs) < checkLen { - checkLen = len(bs) - } - zeroOff := bytes.IndexByte(bs[:checkLen], 0) - if zeroOff == -1 { - return bs - } - - // We assume wsl.exe is trying to print an ASCII codepoint, - // meaning the zero byte is in the upper 8 bits of the - // codepoint. That means we can use the zero's byte offset to - // work out if we're seeing little-endian or big-endian - // UTF-16. - var endian binary.ByteOrder = binary.LittleEndian - if zeroOff%2 == 0 { - endian = binary.BigEndian - } - - var u16 []uint16 - for i := 0; i < len(bs); i += 2 { - u16 = append(u16, endian.Uint16(bs[i:])) - } - return []byte(string(utf16.Decode(u16))) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package dns + +// This code is only used in Windows builds, but is in an +// OS-independent file so tests can run all the time. + +import ( + "bytes" + "encoding/binary" + "unicode/utf16" +) + +// maybeUnUTF16 tries to detect whether bs contains UTF-16, and if so +// translates it to regular UTF-8. +// +// Some of wsl.exe's output get printed as UTF-16, which breaks a +// bunch of things. Try to detect this by looking for a zero byte in +// the first few bytes of output (which will appear if any of those +// codepoints are basic ASCII - very likely). From that we can infer +// that UTF-16 is being printed, and the byte order in use, and we +// decode that back to UTF-8. +// +// https://github.com/microsoft/WSL/issues/4607 +func maybeUnUTF16(bs []byte) []byte { + if len(bs)%2 != 0 { + // Can't be complete UTF-16. + return bs + } + checkLen := 20 + if len(bs) < checkLen { + checkLen = len(bs) + } + zeroOff := bytes.IndexByte(bs[:checkLen], 0) + if zeroOff == -1 { + return bs + } + + // We assume wsl.exe is trying to print an ASCII codepoint, + // meaning the zero byte is in the upper 8 bits of the + // codepoint. That means we can use the zero's byte offset to + // work out if we're seeing little-endian or big-endian + // UTF-16. + var endian binary.ByteOrder = binary.LittleEndian + if zeroOff%2 == 0 { + endian = binary.BigEndian + } + + var u16 []uint16 + for i := 0; i < len(bs); i += 2 { + u16 = append(u16, endian.Uint16(bs[i:])) + } + return []byte(string(utf16.Decode(u16))) +} diff --git a/net/dns/utf_test.go b/net/dns/utf_test.go index fcf593497..b5fd37262 100644 --- a/net/dns/utf_test.go +++ b/net/dns/utf_test.go @@ -1,24 +1,24 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package dns - -import "testing" - -func TestMaybeUnUTF16(t *testing.T) { - tests := []struct { - in string - want string - }{ - {"abc", "abc"}, // UTF-8 - {"a\x00b\x00c\x00", "abc"}, // UTF-16-LE - {"\x00a\x00b\x00c", "abc"}, // UTF-16-BE - } - - for _, test := range tests { - got := string(maybeUnUTF16([]byte(test.in))) - if got != test.want { - t.Errorf("maybeUnUTF16(%q) = %q, want %q", test.in, got, test.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package dns + +import "testing" + +func TestMaybeUnUTF16(t *testing.T) { + tests := []struct { + in string + want string + }{ + {"abc", "abc"}, // UTF-8 + {"a\x00b\x00c\x00", "abc"}, // UTF-16-LE + {"\x00a\x00b\x00c", "abc"}, // UTF-16-BE + } + + for _, test := range tests { + got := string(maybeUnUTF16([]byte(test.in))) + if got != test.want { + t.Errorf("maybeUnUTF16(%q) = %q, want %q", test.in, got, test.want) + } + } +} diff --git a/net/dnscache/dnscache_test.go b/net/dnscache/dnscache_test.go index 6a4b96931..ef4249b74 100644 --- a/net/dnscache/dnscache_test.go +++ b/net/dnscache/dnscache_test.go @@ -1,242 +1,242 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package dnscache - -import ( - "context" - "errors" - "flag" - "fmt" - "net" - "net/netip" - "reflect" - "testing" - "time" - - "tailscale.com/tstest" -) - -var dialTest = flag.String("dial-test", "", "if non-empty, addr:port to test dial") - -func TestDialer(t *testing.T) { - if *dialTest == "" { - t.Skip("skipping; --dial-test is blank") - } - r := &Resolver{Logf: t.Logf} - var std net.Dialer - dialer := Dialer(std.DialContext, r) - t0 := time.Now() - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - c, err := dialer(ctx, "tcp", *dialTest) - if err != nil { - t.Fatal(err) - } - t.Logf("dialed in %v", time.Since(t0)) - c.Close() -} - -func TestDialCall_DNSWasTrustworthy(t *testing.T) { - type step struct { - ip netip.Addr // IP we pretended to dial - err error // the dial error or nil for success - } - mustIP := netip.MustParseAddr - errFail := errors.New("some connect failure") - tests := []struct { - name string - steps []step - want bool - }{ - { - name: "no-info", - want: false, - }, - { - name: "previous-dial", - steps: []step{ - {mustIP("2003::1"), nil}, - {mustIP("2003::1"), errFail}, - }, - want: true, - }, - { - name: "no-previous-dial", - steps: []step{ - {mustIP("2003::1"), errFail}, - }, - want: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - d := &dialer{ - pastConnect: map[netip.Addr]time.Time{}, - } - dc := &dialCall{ - d: d, - } - for _, st := range tt.steps { - dc.noteDialResult(st.ip, st.err) - } - got := dc.dnsWasTrustworthy() - if got != tt.want { - t.Errorf("got %v; want %v", got, tt.want) - } - }) - } -} - -func TestDialCall_uniqueIPs(t *testing.T) { - dc := &dialCall{} - mustIP := netip.MustParseAddr - errFail := errors.New("some connect failure") - dc.noteDialResult(mustIP("2003::1"), errFail) - dc.noteDialResult(mustIP("2003::2"), errFail) - got := dc.uniqueIPs([]netip.Addr{ - mustIP("2003::1"), - mustIP("2003::2"), - mustIP("2003::2"), - mustIP("2003::3"), - mustIP("2003::3"), - mustIP("2003::4"), - mustIP("2003::4"), - }) - want := []netip.Addr{ - mustIP("2003::3"), - mustIP("2003::4"), - } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v; want %v", got, want) - } -} - -func TestResolverAllHostStaticResult(t *testing.T) { - r := &Resolver{ - Logf: t.Logf, - SingleHost: "foo.bar", - SingleHostStaticResult: []netip.Addr{ - netip.MustParseAddr("2001:4860:4860::8888"), - netip.MustParseAddr("2001:4860:4860::8844"), - netip.MustParseAddr("8.8.8.8"), - netip.MustParseAddr("8.8.4.4"), - }, - } - ip4, ip6, allIPs, err := r.LookupIP(context.Background(), "foo.bar") - if err != nil { - t.Fatal(err) - } - if got, want := ip4.String(), "8.8.8.8"; got != want { - t.Errorf("ip4 got %q; want %q", got, want) - } - if got, want := ip6.String(), "2001:4860:4860::8888"; got != want { - t.Errorf("ip4 got %q; want %q", got, want) - } - if got, want := fmt.Sprintf("%q", allIPs), `["2001:4860:4860::8888" "2001:4860:4860::8844" "8.8.8.8" "8.8.4.4"]`; got != want { - t.Errorf("allIPs got %q; want %q", got, want) - } - - _, _, _, err = r.LookupIP(context.Background(), "bad") - if got, want := fmt.Sprint(err), `dnscache: unexpected hostname "bad" doesn't match expected "foo.bar"`; got != want { - t.Errorf("bad dial error got %q; want %q", got, want) - } -} - -func TestShouldTryBootstrap(t *testing.T) { - tstest.Replace(t, &debug, func() bool { return true }) - - type step struct { - ip netip.Addr // IP we pretended to dial - err error // the dial error or nil for success - } - - canceled, cancel := context.WithCancel(context.Background()) - cancel() - - deadlineExceeded, cancel := context.WithTimeout(context.Background(), 0) - defer cancel() - - ctx := context.Background() - errFailed := errors.New("some failure") - - cacheWithFallback := &Resolver{ - Logf: t.Logf, - LookupIPFallback: func(_ context.Context, _ string) ([]netip.Addr, error) { - panic("unimplemented") - }, - } - cacheNoFallback := &Resolver{Logf: t.Logf} - - testCases := []struct { - name string - steps []step - ctx context.Context - err error - noFallback bool - want bool - }{ - { - name: "no-error", - ctx: ctx, - err: nil, - want: false, - }, - { - name: "canceled", - ctx: canceled, - err: errFailed, - want: false, - }, - { - name: "deadline-exceeded", - ctx: deadlineExceeded, - err: errFailed, - want: false, - }, - { - name: "no-fallback", - ctx: ctx, - err: errFailed, - noFallback: true, - want: false, - }, - { - name: "dns-was-trustworthy", - ctx: ctx, - err: errFailed, - steps: []step{ - {netip.MustParseAddr("2003::1"), nil}, - {netip.MustParseAddr("2003::1"), errFailed}, - }, - want: false, - }, - { - name: "should-bootstrap", - ctx: ctx, - err: errFailed, - want: true, - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - d := &dialer{ - pastConnect: map[netip.Addr]time.Time{}, - } - if tt.noFallback { - d.dnsCache = cacheNoFallback - } else { - d.dnsCache = cacheWithFallback - } - dc := &dialCall{d: d} - for _, st := range tt.steps { - dc.noteDialResult(st.ip, st.err) - } - got := d.shouldTryBootstrap(tt.ctx, tt.err, dc) - if got != tt.want { - t.Errorf("got %v; want %v", got, tt.want) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package dnscache + +import ( + "context" + "errors" + "flag" + "fmt" + "net" + "net/netip" + "reflect" + "testing" + "time" + + "tailscale.com/tstest" +) + +var dialTest = flag.String("dial-test", "", "if non-empty, addr:port to test dial") + +func TestDialer(t *testing.T) { + if *dialTest == "" { + t.Skip("skipping; --dial-test is blank") + } + r := &Resolver{Logf: t.Logf} + var std net.Dialer + dialer := Dialer(std.DialContext, r) + t0 := time.Now() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + c, err := dialer(ctx, "tcp", *dialTest) + if err != nil { + t.Fatal(err) + } + t.Logf("dialed in %v", time.Since(t0)) + c.Close() +} + +func TestDialCall_DNSWasTrustworthy(t *testing.T) { + type step struct { + ip netip.Addr // IP we pretended to dial + err error // the dial error or nil for success + } + mustIP := netip.MustParseAddr + errFail := errors.New("some connect failure") + tests := []struct { + name string + steps []step + want bool + }{ + { + name: "no-info", + want: false, + }, + { + name: "previous-dial", + steps: []step{ + {mustIP("2003::1"), nil}, + {mustIP("2003::1"), errFail}, + }, + want: true, + }, + { + name: "no-previous-dial", + steps: []step{ + {mustIP("2003::1"), errFail}, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := &dialer{ + pastConnect: map[netip.Addr]time.Time{}, + } + dc := &dialCall{ + d: d, + } + for _, st := range tt.steps { + dc.noteDialResult(st.ip, st.err) + } + got := dc.dnsWasTrustworthy() + if got != tt.want { + t.Errorf("got %v; want %v", got, tt.want) + } + }) + } +} + +func TestDialCall_uniqueIPs(t *testing.T) { + dc := &dialCall{} + mustIP := netip.MustParseAddr + errFail := errors.New("some connect failure") + dc.noteDialResult(mustIP("2003::1"), errFail) + dc.noteDialResult(mustIP("2003::2"), errFail) + got := dc.uniqueIPs([]netip.Addr{ + mustIP("2003::1"), + mustIP("2003::2"), + mustIP("2003::2"), + mustIP("2003::3"), + mustIP("2003::3"), + mustIP("2003::4"), + mustIP("2003::4"), + }) + want := []netip.Addr{ + mustIP("2003::3"), + mustIP("2003::4"), + } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v; want %v", got, want) + } +} + +func TestResolverAllHostStaticResult(t *testing.T) { + r := &Resolver{ + Logf: t.Logf, + SingleHost: "foo.bar", + SingleHostStaticResult: []netip.Addr{ + netip.MustParseAddr("2001:4860:4860::8888"), + netip.MustParseAddr("2001:4860:4860::8844"), + netip.MustParseAddr("8.8.8.8"), + netip.MustParseAddr("8.8.4.4"), + }, + } + ip4, ip6, allIPs, err := r.LookupIP(context.Background(), "foo.bar") + if err != nil { + t.Fatal(err) + } + if got, want := ip4.String(), "8.8.8.8"; got != want { + t.Errorf("ip4 got %q; want %q", got, want) + } + if got, want := ip6.String(), "2001:4860:4860::8888"; got != want { + t.Errorf("ip4 got %q; want %q", got, want) + } + if got, want := fmt.Sprintf("%q", allIPs), `["2001:4860:4860::8888" "2001:4860:4860::8844" "8.8.8.8" "8.8.4.4"]`; got != want { + t.Errorf("allIPs got %q; want %q", got, want) + } + + _, _, _, err = r.LookupIP(context.Background(), "bad") + if got, want := fmt.Sprint(err), `dnscache: unexpected hostname "bad" doesn't match expected "foo.bar"`; got != want { + t.Errorf("bad dial error got %q; want %q", got, want) + } +} + +func TestShouldTryBootstrap(t *testing.T) { + tstest.Replace(t, &debug, func() bool { return true }) + + type step struct { + ip netip.Addr // IP we pretended to dial + err error // the dial error or nil for success + } + + canceled, cancel := context.WithCancel(context.Background()) + cancel() + + deadlineExceeded, cancel := context.WithTimeout(context.Background(), 0) + defer cancel() + + ctx := context.Background() + errFailed := errors.New("some failure") + + cacheWithFallback := &Resolver{ + Logf: t.Logf, + LookupIPFallback: func(_ context.Context, _ string) ([]netip.Addr, error) { + panic("unimplemented") + }, + } + cacheNoFallback := &Resolver{Logf: t.Logf} + + testCases := []struct { + name string + steps []step + ctx context.Context + err error + noFallback bool + want bool + }{ + { + name: "no-error", + ctx: ctx, + err: nil, + want: false, + }, + { + name: "canceled", + ctx: canceled, + err: errFailed, + want: false, + }, + { + name: "deadline-exceeded", + ctx: deadlineExceeded, + err: errFailed, + want: false, + }, + { + name: "no-fallback", + ctx: ctx, + err: errFailed, + noFallback: true, + want: false, + }, + { + name: "dns-was-trustworthy", + ctx: ctx, + err: errFailed, + steps: []step{ + {netip.MustParseAddr("2003::1"), nil}, + {netip.MustParseAddr("2003::1"), errFailed}, + }, + want: false, + }, + { + name: "should-bootstrap", + ctx: ctx, + err: errFailed, + want: true, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + d := &dialer{ + pastConnect: map[netip.Addr]time.Time{}, + } + if tt.noFallback { + d.dnsCache = cacheNoFallback + } else { + d.dnsCache = cacheWithFallback + } + dc := &dialCall{d: d} + for _, st := range tt.steps { + dc.noteDialResult(st.ip, st.err) + } + got := d.shouldTryBootstrap(tt.ctx, tt.err, dc) + if got != tt.want { + t.Errorf("got %v; want %v", got, tt.want) + } + }) + } +} diff --git a/net/dnscache/messagecache_test.go b/net/dnscache/messagecache_test.go index 18af32459..41fc33448 100644 --- a/net/dnscache/messagecache_test.go +++ b/net/dnscache/messagecache_test.go @@ -1,291 +1,291 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package dnscache - -import ( - "bytes" - "context" - "errors" - "fmt" - "net" - "runtime" - "testing" - "time" - - "golang.org/x/net/dns/dnsmessage" - "tailscale.com/tstest" -) - -func TestMessageCache(t *testing.T) { - clock := tstest.NewClock(tstest.ClockOpts{ - Start: time.Date(1987, 11, 1, 0, 0, 0, 0, time.UTC), - }) - mc := &MessageCache{Clock: clock.Now} - mc.SetMaxCacheSize(2) - clock.Advance(time.Second) - - var out bytes.Buffer - if err := mc.ReplyFromCache(&out, makeQ(1, "foo.com.")); err != ErrCacheMiss { - t.Fatalf("unexpected error: %v", err) - } - - if err := mc.AddCacheEntry( - makeQ(2, "foo.com."), - makeRes(2, "FOO.COM.", ttlOpt(10), - &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}}, - &dnsmessage.AResource{A: [4]byte{127, 0, 0, 2}})); err != nil { - t.Fatal(err) - } - - // Expect cache hit, with 10 seconds remaining. - out.Reset() - if err := mc.ReplyFromCache(&out, makeQ(3, "foo.com.")); err != nil { - t.Fatalf("expected cache hit; got: %v", err) - } - if p := mustParseResponse(t, out.Bytes()); p.TxID != 3 { - t.Errorf("TxID = %v; want %v", p.TxID, 3) - } else if p.TTL != 10 { - t.Errorf("TTL = %v; want 10", p.TTL) - } - - // One second elapses, expect a cache hit, with 9 seconds - // remaining. - clock.Advance(time.Second) - out.Reset() - if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.")); err != nil { - t.Fatalf("expected cache hit; got: %v", err) - } - if p := mustParseResponse(t, out.Bytes()); p.TxID != 4 { - t.Errorf("TxID = %v; want %v", p.TxID, 4) - } else if p.TTL != 9 { - t.Errorf("TTL = %v; want 9", p.TTL) - } - - // Expect cache miss on MX record. - if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.", dnsmessage.TypeMX)); err != ErrCacheMiss { - t.Fatalf("expected cache miss on MX; got: %v", err) - } - // Expect cache miss on CHAOS class. - if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.", dnsmessage.ClassCHAOS)); err != ErrCacheMiss { - t.Fatalf("expected cache miss on CHAOS; got: %v", err) - } - - // Ten seconds elapses; expect a cache miss. - clock.Advance(10 * time.Second) - if err := mc.ReplyFromCache(&out, makeQ(5, "foo.com.")); err != ErrCacheMiss { - t.Fatalf("expected cache miss, got: %v", err) - } -} - -type parsedMeta struct { - TxID uint16 - TTL uint32 -} - -func mustParseResponse(t testing.TB, r []byte) (ret parsedMeta) { - t.Helper() - var p dnsmessage.Parser - h, err := p.Start(r) - if err != nil { - t.Fatal(err) - } - ret.TxID = h.ID - qq, err := p.AllQuestions() - if err != nil { - t.Fatalf("AllQuestions: %v", err) - } - if len(qq) != 1 { - t.Fatalf("num questions = %v; want 1", len(qq)) - } - aa, err := p.AllAnswers() - if err != nil { - t.Fatalf("AllAnswers: %v", err) - } - for _, r := range aa { - if ret.TTL == 0 { - ret.TTL = r.Header.TTL - } - if ret.TTL != r.Header.TTL { - t.Fatal("mixed TTLs") - } - } - return ret -} - -type responseOpt bool - -type ttlOpt uint32 - -func makeQ(txID uint16, name string, opt ...any) []byte { - opt = append(opt, responseOpt(false)) - return makeDNSPkt(txID, name, opt...) -} - -func makeRes(txID uint16, name string, opt ...any) []byte { - opt = append(opt, responseOpt(true)) - return makeDNSPkt(txID, name, opt...) -} - -func makeDNSPkt(txID uint16, name string, opt ...any) []byte { - typ := dnsmessage.TypeA - class := dnsmessage.ClassINET - var response bool - var answers []dnsmessage.ResourceBody - var ttl uint32 = 1 // one second by default - for _, o := range opt { - switch o := o.(type) { - case dnsmessage.Type: - typ = o - case dnsmessage.Class: - class = o - case responseOpt: - response = bool(o) - case dnsmessage.ResourceBody: - answers = append(answers, o) - case ttlOpt: - ttl = uint32(o) - default: - panic(fmt.Sprintf("unknown opt type %T", o)) - } - } - qname := dnsmessage.MustNewName(name) - msg := dnsmessage.Message{ - Header: dnsmessage.Header{ID: txID, Response: response}, - Questions: []dnsmessage.Question{ - { - Name: qname, - Type: typ, - Class: class, - }, - }, - } - for _, rb := range answers { - msg.Answers = append(msg.Answers, dnsmessage.Resource{ - Header: dnsmessage.ResourceHeader{ - Name: qname, - Type: typ, - Class: class, - TTL: ttl, - }, - Body: rb, - }) - } - buf, err := msg.Pack() - if err != nil { - panic(err) - } - return buf -} - -func TestASCIILowerName(t *testing.T) { - n := asciiLowerName(dnsmessage.MustNewName("Foo.COM.")) - if got, want := n.String(), "foo.com."; got != want { - t.Errorf("got = %q; want %q", got, want) - } -} - -func TestGetDNSQueryCacheKey(t *testing.T) { - tests := []struct { - name string - pkt []byte - want msgQ - txID uint16 - anyTX bool - }{ - { - name: "empty", - }, - { - name: "a", - pkt: makeQ(123, "foo.com."), - want: msgQ{"foo.com.", dnsmessage.TypeA}, - txID: 123, - }, - { - name: "aaaa", - pkt: makeQ(6, "foo.com.", dnsmessage.TypeAAAA), - want: msgQ{"foo.com.", dnsmessage.TypeAAAA}, - txID: 6, - }, - { - name: "normalize_case", - pkt: makeQ(123, "FoO.CoM."), - want: msgQ{"foo.com.", dnsmessage.TypeA}, - txID: 123, - }, - { - name: "ignore_response", - pkt: makeRes(123, "foo.com."), - }, - { - name: "ignore_question_with_answers", - pkt: makeQ(2, "foo.com.", &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}}), - }, - { - name: "whatever_go_generates", // in case Go's net package grows functionality we don't handle - pkt: getGoNetPacketDNSQuery("from-go.foo."), - want: msgQ{"from-go.foo.", dnsmessage.TypeA}, - anyTX: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, gotTX, ok := getDNSQueryCacheKey(tt.pkt) - if !ok { - if tt.txID == 0 && got == (msgQ{}) { - return - } - t.Fatal("failed") - } - if got != tt.want { - t.Errorf("got %+v, want %+v", got, tt.want) - } - if gotTX != tt.txID && !tt.anyTX { - t.Errorf("got tx %v, want %v", gotTX, tt.txID) - } - }) - } -} - -func getGoNetPacketDNSQuery(name string) []byte { - if runtime.GOOS == "windows" { - // On Windows, Go's net.Resolver doesn't use the DNS client. - // See https://github.com/golang/go/issues/33097 which - // was approved but not yet implemented. - // For now just pretend it's implemented to make this test - // pass on Windows with complicated the caller. - return makeQ(123, name) - } - res := make(chan []byte, 1) - r := &net.Resolver{ - PreferGo: true, - Dial: func(ctx context.Context, network, address string) (net.Conn, error) { - return goResolverConn(res), nil - }, - } - r.LookupIP(context.Background(), "ip4", name) - return <-res -} - -type goResolverConn chan<- []byte - -func (goResolverConn) Close() error { return nil } -func (goResolverConn) LocalAddr() net.Addr { return todoAddr{} } -func (goResolverConn) RemoteAddr() net.Addr { return todoAddr{} } -func (goResolverConn) SetDeadline(t time.Time) error { return nil } -func (goResolverConn) SetReadDeadline(t time.Time) error { return nil } -func (goResolverConn) SetWriteDeadline(t time.Time) error { return nil } -func (goResolverConn) Read([]byte) (int, error) { return 0, errors.New("boom") } -func (c goResolverConn) Write(p []byte) (int, error) { - select { - case c <- p[2:]: // skip 2 byte length for TCP mode DNS query - default: - } - return 0, errors.New("boom") -} - -type todoAddr struct{} - -func (todoAddr) Network() string { return "unused" } -func (todoAddr) String() string { return "unused-todoAddr" } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package dnscache + +import ( + "bytes" + "context" + "errors" + "fmt" + "net" + "runtime" + "testing" + "time" + + "golang.org/x/net/dns/dnsmessage" + "tailscale.com/tstest" +) + +func TestMessageCache(t *testing.T) { + clock := tstest.NewClock(tstest.ClockOpts{ + Start: time.Date(1987, 11, 1, 0, 0, 0, 0, time.UTC), + }) + mc := &MessageCache{Clock: clock.Now} + mc.SetMaxCacheSize(2) + clock.Advance(time.Second) + + var out bytes.Buffer + if err := mc.ReplyFromCache(&out, makeQ(1, "foo.com.")); err != ErrCacheMiss { + t.Fatalf("unexpected error: %v", err) + } + + if err := mc.AddCacheEntry( + makeQ(2, "foo.com."), + makeRes(2, "FOO.COM.", ttlOpt(10), + &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}}, + &dnsmessage.AResource{A: [4]byte{127, 0, 0, 2}})); err != nil { + t.Fatal(err) + } + + // Expect cache hit, with 10 seconds remaining. + out.Reset() + if err := mc.ReplyFromCache(&out, makeQ(3, "foo.com.")); err != nil { + t.Fatalf("expected cache hit; got: %v", err) + } + if p := mustParseResponse(t, out.Bytes()); p.TxID != 3 { + t.Errorf("TxID = %v; want %v", p.TxID, 3) + } else if p.TTL != 10 { + t.Errorf("TTL = %v; want 10", p.TTL) + } + + // One second elapses, expect a cache hit, with 9 seconds + // remaining. + clock.Advance(time.Second) + out.Reset() + if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.")); err != nil { + t.Fatalf("expected cache hit; got: %v", err) + } + if p := mustParseResponse(t, out.Bytes()); p.TxID != 4 { + t.Errorf("TxID = %v; want %v", p.TxID, 4) + } else if p.TTL != 9 { + t.Errorf("TTL = %v; want 9", p.TTL) + } + + // Expect cache miss on MX record. + if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.", dnsmessage.TypeMX)); err != ErrCacheMiss { + t.Fatalf("expected cache miss on MX; got: %v", err) + } + // Expect cache miss on CHAOS class. + if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.", dnsmessage.ClassCHAOS)); err != ErrCacheMiss { + t.Fatalf("expected cache miss on CHAOS; got: %v", err) + } + + // Ten seconds elapses; expect a cache miss. + clock.Advance(10 * time.Second) + if err := mc.ReplyFromCache(&out, makeQ(5, "foo.com.")); err != ErrCacheMiss { + t.Fatalf("expected cache miss, got: %v", err) + } +} + +type parsedMeta struct { + TxID uint16 + TTL uint32 +} + +func mustParseResponse(t testing.TB, r []byte) (ret parsedMeta) { + t.Helper() + var p dnsmessage.Parser + h, err := p.Start(r) + if err != nil { + t.Fatal(err) + } + ret.TxID = h.ID + qq, err := p.AllQuestions() + if err != nil { + t.Fatalf("AllQuestions: %v", err) + } + if len(qq) != 1 { + t.Fatalf("num questions = %v; want 1", len(qq)) + } + aa, err := p.AllAnswers() + if err != nil { + t.Fatalf("AllAnswers: %v", err) + } + for _, r := range aa { + if ret.TTL == 0 { + ret.TTL = r.Header.TTL + } + if ret.TTL != r.Header.TTL { + t.Fatal("mixed TTLs") + } + } + return ret +} + +type responseOpt bool + +type ttlOpt uint32 + +func makeQ(txID uint16, name string, opt ...any) []byte { + opt = append(opt, responseOpt(false)) + return makeDNSPkt(txID, name, opt...) +} + +func makeRes(txID uint16, name string, opt ...any) []byte { + opt = append(opt, responseOpt(true)) + return makeDNSPkt(txID, name, opt...) +} + +func makeDNSPkt(txID uint16, name string, opt ...any) []byte { + typ := dnsmessage.TypeA + class := dnsmessage.ClassINET + var response bool + var answers []dnsmessage.ResourceBody + var ttl uint32 = 1 // one second by default + for _, o := range opt { + switch o := o.(type) { + case dnsmessage.Type: + typ = o + case dnsmessage.Class: + class = o + case responseOpt: + response = bool(o) + case dnsmessage.ResourceBody: + answers = append(answers, o) + case ttlOpt: + ttl = uint32(o) + default: + panic(fmt.Sprintf("unknown opt type %T", o)) + } + } + qname := dnsmessage.MustNewName(name) + msg := dnsmessage.Message{ + Header: dnsmessage.Header{ID: txID, Response: response}, + Questions: []dnsmessage.Question{ + { + Name: qname, + Type: typ, + Class: class, + }, + }, + } + for _, rb := range answers { + msg.Answers = append(msg.Answers, dnsmessage.Resource{ + Header: dnsmessage.ResourceHeader{ + Name: qname, + Type: typ, + Class: class, + TTL: ttl, + }, + Body: rb, + }) + } + buf, err := msg.Pack() + if err != nil { + panic(err) + } + return buf +} + +func TestASCIILowerName(t *testing.T) { + n := asciiLowerName(dnsmessage.MustNewName("Foo.COM.")) + if got, want := n.String(), "foo.com."; got != want { + t.Errorf("got = %q; want %q", got, want) + } +} + +func TestGetDNSQueryCacheKey(t *testing.T) { + tests := []struct { + name string + pkt []byte + want msgQ + txID uint16 + anyTX bool + }{ + { + name: "empty", + }, + { + name: "a", + pkt: makeQ(123, "foo.com."), + want: msgQ{"foo.com.", dnsmessage.TypeA}, + txID: 123, + }, + { + name: "aaaa", + pkt: makeQ(6, "foo.com.", dnsmessage.TypeAAAA), + want: msgQ{"foo.com.", dnsmessage.TypeAAAA}, + txID: 6, + }, + { + name: "normalize_case", + pkt: makeQ(123, "FoO.CoM."), + want: msgQ{"foo.com.", dnsmessage.TypeA}, + txID: 123, + }, + { + name: "ignore_response", + pkt: makeRes(123, "foo.com."), + }, + { + name: "ignore_question_with_answers", + pkt: makeQ(2, "foo.com.", &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}}), + }, + { + name: "whatever_go_generates", // in case Go's net package grows functionality we don't handle + pkt: getGoNetPacketDNSQuery("from-go.foo."), + want: msgQ{"from-go.foo.", dnsmessage.TypeA}, + anyTX: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, gotTX, ok := getDNSQueryCacheKey(tt.pkt) + if !ok { + if tt.txID == 0 && got == (msgQ{}) { + return + } + t.Fatal("failed") + } + if got != tt.want { + t.Errorf("got %+v, want %+v", got, tt.want) + } + if gotTX != tt.txID && !tt.anyTX { + t.Errorf("got tx %v, want %v", gotTX, tt.txID) + } + }) + } +} + +func getGoNetPacketDNSQuery(name string) []byte { + if runtime.GOOS == "windows" { + // On Windows, Go's net.Resolver doesn't use the DNS client. + // See https://github.com/golang/go/issues/33097 which + // was approved but not yet implemented. + // For now just pretend it's implemented to make this test + // pass on Windows with complicated the caller. + return makeQ(123, name) + } + res := make(chan []byte, 1) + r := &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + return goResolverConn(res), nil + }, + } + r.LookupIP(context.Background(), "ip4", name) + return <-res +} + +type goResolverConn chan<- []byte + +func (goResolverConn) Close() error { return nil } +func (goResolverConn) LocalAddr() net.Addr { return todoAddr{} } +func (goResolverConn) RemoteAddr() net.Addr { return todoAddr{} } +func (goResolverConn) SetDeadline(t time.Time) error { return nil } +func (goResolverConn) SetReadDeadline(t time.Time) error { return nil } +func (goResolverConn) SetWriteDeadline(t time.Time) error { return nil } +func (goResolverConn) Read([]byte) (int, error) { return 0, errors.New("boom") } +func (c goResolverConn) Write(p []byte) (int, error) { + select { + case c <- p[2:]: // skip 2 byte length for TCP mode DNS query + default: + } + return 0, errors.New("boom") +} + +type todoAddr struct{} + +func (todoAddr) Network() string { return "unused" } +func (todoAddr) String() string { return "unused-todoAddr" } diff --git a/net/dnsfallback/update-dns-fallbacks.go b/net/dnsfallback/update-dns-fallbacks.go index ebbfc2ad1..384e77e10 100644 --- a/net/dnsfallback/update-dns-fallbacks.go +++ b/net/dnsfallback/update-dns-fallbacks.go @@ -1,45 +1,45 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ignore - -package main - -import ( - "encoding/json" - "fmt" - "log" - "net/http" - "os" - - "tailscale.com/tailcfg" -) - -func main() { - res, err := http.Get("https://login.tailscale.com/derpmap/default") - if err != nil { - log.Fatal(err) - } - if res.StatusCode != 200 { - res.Write(os.Stderr) - os.Exit(1) - } - dm := new(tailcfg.DERPMap) - if err := json.NewDecoder(res.Body).Decode(dm); err != nil { - log.Fatal(err) - } - for rid, r := range dm.Regions { - // Names misleading to check into git, as this is a - // static snapshot and doesn't reflect the live DERP - // map. - r.RegionCode = fmt.Sprintf("r%d", rid) - r.RegionName = r.RegionCode - } - out, err := json.MarshalIndent(dm, "", "\t") - if err != nil { - log.Fatal(err) - } - if err := os.WriteFile("dns-fallback-servers.json", out, 0644); err != nil { - log.Fatal(err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ignore + +package main + +import ( + "encoding/json" + "fmt" + "log" + "net/http" + "os" + + "tailscale.com/tailcfg" +) + +func main() { + res, err := http.Get("https://login.tailscale.com/derpmap/default") + if err != nil { + log.Fatal(err) + } + if res.StatusCode != 200 { + res.Write(os.Stderr) + os.Exit(1) + } + dm := new(tailcfg.DERPMap) + if err := json.NewDecoder(res.Body).Decode(dm); err != nil { + log.Fatal(err) + } + for rid, r := range dm.Regions { + // Names misleading to check into git, as this is a + // static snapshot and doesn't reflect the live DERP + // map. + r.RegionCode = fmt.Sprintf("r%d", rid) + r.RegionName = r.RegionCode + } + out, err := json.MarshalIndent(dm, "", "\t") + if err != nil { + log.Fatal(err) + } + if err := os.WriteFile("dns-fallback-servers.json", out, 0644); err != nil { + log.Fatal(err) + } +} diff --git a/net/memnet/conn.go b/net/memnet/conn.go index f599612d9..a9e1fd399 100644 --- a/net/memnet/conn.go +++ b/net/memnet/conn.go @@ -1,114 +1,114 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package memnet - -import ( - "net" - "net/netip" - "time" -) - -// NetworkName is the network name returned by [net.Addr.Network] -// for [net.Conn.LocalAddr] and [net.Conn.RemoteAddr] from the [Conn] type. -const NetworkName = "mem" - -// Conn is a net.Conn that can additionally have its reads and writes blocked and unblocked. -type Conn interface { - net.Conn - - // SetReadBlock blocks or unblocks the Read method of this Conn. - // It reports an error if the existing value matches the new value, - // or if the Conn has been Closed. - SetReadBlock(bool) error - - // SetWriteBlock blocks or unblocks the Write method of this Conn. - // It reports an error if the existing value matches the new value, - // or if the Conn has been Closed. - SetWriteBlock(bool) error -} - -// NewConn creates a pair of Conns that are wired together by pipes. -func NewConn(name string, maxBuf int) (Conn, Conn) { - r := NewPipe(name+"|0", maxBuf) - w := NewPipe(name+"|1", maxBuf) - - return &connHalf{r: r, w: w}, &connHalf{r: w, w: r} -} - -// NewTCPConn creates a pair of Conns that are wired together by pipes. -func NewTCPConn(src, dst netip.AddrPort, maxBuf int) (local Conn, remote Conn) { - r := NewPipe(src.String(), maxBuf) - w := NewPipe(dst.String(), maxBuf) - - lAddr := net.TCPAddrFromAddrPort(src) - rAddr := net.TCPAddrFromAddrPort(dst) - - return &connHalf{r: r, w: w, remote: rAddr, local: lAddr}, &connHalf{r: w, w: r, remote: lAddr, local: rAddr} -} - -type connAddr string - -func (a connAddr) Network() string { return NetworkName } -func (a connAddr) String() string { return string(a) } - -type connHalf struct { - local, remote net.Addr - r, w *Pipe -} - -func (c *connHalf) LocalAddr() net.Addr { - if c.local != nil { - return c.local - } - return connAddr(c.r.name) -} - -func (c *connHalf) RemoteAddr() net.Addr { - if c.remote != nil { - return c.remote - } - return connAddr(c.w.name) -} - -func (c *connHalf) Read(b []byte) (n int, err error) { - return c.r.Read(b) -} -func (c *connHalf) Write(b []byte) (n int, err error) { - return c.w.Write(b) -} - -func (c *connHalf) Close() error { - if err := c.w.Close(); err != nil { - return err - } - return c.r.Close() -} - -func (c *connHalf) SetDeadline(t time.Time) error { - err1 := c.SetReadDeadline(t) - err2 := c.SetWriteDeadline(t) - if err1 != nil { - return err1 - } - return err2 -} -func (c *connHalf) SetReadDeadline(t time.Time) error { - return c.r.SetReadDeadline(t) -} -func (c *connHalf) SetWriteDeadline(t time.Time) error { - return c.w.SetWriteDeadline(t) -} - -func (c *connHalf) SetReadBlock(b bool) error { - if b { - return c.r.Block() - } - return c.r.Unblock() -} -func (c *connHalf) SetWriteBlock(b bool) error { - if b { - return c.w.Block() - } - return c.w.Unblock() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package memnet + +import ( + "net" + "net/netip" + "time" +) + +// NetworkName is the network name returned by [net.Addr.Network] +// for [net.Conn.LocalAddr] and [net.Conn.RemoteAddr] from the [Conn] type. +const NetworkName = "mem" + +// Conn is a net.Conn that can additionally have its reads and writes blocked and unblocked. +type Conn interface { + net.Conn + + // SetReadBlock blocks or unblocks the Read method of this Conn. + // It reports an error if the existing value matches the new value, + // or if the Conn has been Closed. + SetReadBlock(bool) error + + // SetWriteBlock blocks or unblocks the Write method of this Conn. + // It reports an error if the existing value matches the new value, + // or if the Conn has been Closed. + SetWriteBlock(bool) error +} + +// NewConn creates a pair of Conns that are wired together by pipes. +func NewConn(name string, maxBuf int) (Conn, Conn) { + r := NewPipe(name+"|0", maxBuf) + w := NewPipe(name+"|1", maxBuf) + + return &connHalf{r: r, w: w}, &connHalf{r: w, w: r} +} + +// NewTCPConn creates a pair of Conns that are wired together by pipes. +func NewTCPConn(src, dst netip.AddrPort, maxBuf int) (local Conn, remote Conn) { + r := NewPipe(src.String(), maxBuf) + w := NewPipe(dst.String(), maxBuf) + + lAddr := net.TCPAddrFromAddrPort(src) + rAddr := net.TCPAddrFromAddrPort(dst) + + return &connHalf{r: r, w: w, remote: rAddr, local: lAddr}, &connHalf{r: w, w: r, remote: lAddr, local: rAddr} +} + +type connAddr string + +func (a connAddr) Network() string { return NetworkName } +func (a connAddr) String() string { return string(a) } + +type connHalf struct { + local, remote net.Addr + r, w *Pipe +} + +func (c *connHalf) LocalAddr() net.Addr { + if c.local != nil { + return c.local + } + return connAddr(c.r.name) +} + +func (c *connHalf) RemoteAddr() net.Addr { + if c.remote != nil { + return c.remote + } + return connAddr(c.w.name) +} + +func (c *connHalf) Read(b []byte) (n int, err error) { + return c.r.Read(b) +} +func (c *connHalf) Write(b []byte) (n int, err error) { + return c.w.Write(b) +} + +func (c *connHalf) Close() error { + if err := c.w.Close(); err != nil { + return err + } + return c.r.Close() +} + +func (c *connHalf) SetDeadline(t time.Time) error { + err1 := c.SetReadDeadline(t) + err2 := c.SetWriteDeadline(t) + if err1 != nil { + return err1 + } + return err2 +} +func (c *connHalf) SetReadDeadline(t time.Time) error { + return c.r.SetReadDeadline(t) +} +func (c *connHalf) SetWriteDeadline(t time.Time) error { + return c.w.SetWriteDeadline(t) +} + +func (c *connHalf) SetReadBlock(b bool) error { + if b { + return c.r.Block() + } + return c.r.Unblock() +} +func (c *connHalf) SetWriteBlock(b bool) error { + if b { + return c.w.Block() + } + return c.w.Unblock() +} diff --git a/net/memnet/conn_test.go b/net/memnet/conn_test.go index 3eec80bc6..743ce5248 100644 --- a/net/memnet/conn_test.go +++ b/net/memnet/conn_test.go @@ -1,21 +1,21 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package memnet - -import ( - "net" - "testing" - - "golang.org/x/net/nettest" -) - -func TestConn(t *testing.T) { - nettest.TestConn(t, func() (c1 net.Conn, c2 net.Conn, stop func(), err error) { - c1, c2 = NewConn("test", bufferSize) - return c1, c2, func() { - c1.Close() - c2.Close() - }, nil - }) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package memnet + +import ( + "net" + "testing" + + "golang.org/x/net/nettest" +) + +func TestConn(t *testing.T) { + nettest.TestConn(t, func() (c1 net.Conn, c2 net.Conn, stop func(), err error) { + c1, c2 = NewConn("test", bufferSize) + return c1, c2, func() { + c1.Close() + c2.Close() + }, nil + }) +} diff --git a/net/memnet/listener.go b/net/memnet/listener.go index d1364d790..d84a2e443 100644 --- a/net/memnet/listener.go +++ b/net/memnet/listener.go @@ -1,100 +1,100 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package memnet - -import ( - "context" - "net" - "strings" - "sync" -) - -const ( - bufferSize = 256 * 1024 -) - -// Listener is a net.Listener using NewConn to create pairs of network -// connections connected in memory using a buffered pipe. It also provides a -// Dial method to establish new connections. -type Listener struct { - addr connAddr - ch chan Conn - closeOnce sync.Once - closed chan struct{} - - // NewConn, if non-nil, is called to create a new pair of connections - // when dialing. If nil, NewConn is used. - NewConn func(network, addr string, maxBuf int) (Conn, Conn) -} - -// Listen returns a new Listener for the provided address. -func Listen(addr string) *Listener { - return &Listener{ - addr: connAddr(addr), - ch: make(chan Conn), - closed: make(chan struct{}), - } -} - -// Addr implements net.Listener.Addr. -func (l *Listener) Addr() net.Addr { - return l.addr -} - -// Close closes the pipe listener. -func (l *Listener) Close() error { - l.closeOnce.Do(func() { - close(l.closed) - }) - return nil -} - -// Accept blocks until a new connection is available or the listener is closed. -func (l *Listener) Accept() (net.Conn, error) { - select { - case c := <-l.ch: - return c, nil - case <-l.closed: - return nil, net.ErrClosed - } -} - -// Dial connects to the listener using the provided context. -// The provided Context must be non-nil. If the context expires before the -// connection is complete, an error is returned. Once successfully connected -// any expiration of the context will not affect the connection. -func (l *Listener) Dial(ctx context.Context, network, addr string) (_ net.Conn, err error) { - if !strings.HasSuffix(network, "tcp") { - return nil, net.UnknownNetworkError(network) - } - if connAddr(addr) != l.addr { - return nil, &net.AddrError{ - Err: "invalid address", - Addr: addr, - } - } - - newConn := l.NewConn - if newConn == nil { - newConn = func(network, addr string, maxBuf int) (Conn, Conn) { - return NewConn(addr, maxBuf) - } - } - c, s := newConn(network, addr, bufferSize) - defer func() { - if err != nil { - c.Close() - s.Close() - } - }() - - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-l.closed: - return nil, net.ErrClosed - case l.ch <- s: - return c, nil - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package memnet + +import ( + "context" + "net" + "strings" + "sync" +) + +const ( + bufferSize = 256 * 1024 +) + +// Listener is a net.Listener using NewConn to create pairs of network +// connections connected in memory using a buffered pipe. It also provides a +// Dial method to establish new connections. +type Listener struct { + addr connAddr + ch chan Conn + closeOnce sync.Once + closed chan struct{} + + // NewConn, if non-nil, is called to create a new pair of connections + // when dialing. If nil, NewConn is used. + NewConn func(network, addr string, maxBuf int) (Conn, Conn) +} + +// Listen returns a new Listener for the provided address. +func Listen(addr string) *Listener { + return &Listener{ + addr: connAddr(addr), + ch: make(chan Conn), + closed: make(chan struct{}), + } +} + +// Addr implements net.Listener.Addr. +func (l *Listener) Addr() net.Addr { + return l.addr +} + +// Close closes the pipe listener. +func (l *Listener) Close() error { + l.closeOnce.Do(func() { + close(l.closed) + }) + return nil +} + +// Accept blocks until a new connection is available or the listener is closed. +func (l *Listener) Accept() (net.Conn, error) { + select { + case c := <-l.ch: + return c, nil + case <-l.closed: + return nil, net.ErrClosed + } +} + +// Dial connects to the listener using the provided context. +// The provided Context must be non-nil. If the context expires before the +// connection is complete, an error is returned. Once successfully connected +// any expiration of the context will not affect the connection. +func (l *Listener) Dial(ctx context.Context, network, addr string) (_ net.Conn, err error) { + if !strings.HasSuffix(network, "tcp") { + return nil, net.UnknownNetworkError(network) + } + if connAddr(addr) != l.addr { + return nil, &net.AddrError{ + Err: "invalid address", + Addr: addr, + } + } + + newConn := l.NewConn + if newConn == nil { + newConn = func(network, addr string, maxBuf int) (Conn, Conn) { + return NewConn(addr, maxBuf) + } + } + c, s := newConn(network, addr, bufferSize) + defer func() { + if err != nil { + c.Close() + s.Close() + } + }() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-l.closed: + return nil, net.ErrClosed + case l.ch <- s: + return c, nil + } +} diff --git a/net/memnet/listener_test.go b/net/memnet/listener_test.go index 989d5e9e4..73b67841a 100644 --- a/net/memnet/listener_test.go +++ b/net/memnet/listener_test.go @@ -1,33 +1,33 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package memnet - -import ( - "context" - "testing" -) - -func TestListener(t *testing.T) { - l := Listen("srv.local") - defer l.Close() - go func() { - c, err := l.Accept() - if err != nil { - t.Error(err) - return - } - defer c.Close() - }() - - if c, err := l.Dial(context.Background(), "tcp", "invalid"); err == nil { - c.Close() - t.Fatalf("dial to invalid address succeeded") - } - c, err := l.Dial(context.Background(), "tcp", "srv.local") - if err != nil { - t.Fatalf("dial failed: %v", err) - return - } - c.Close() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package memnet + +import ( + "context" + "testing" +) + +func TestListener(t *testing.T) { + l := Listen("srv.local") + defer l.Close() + go func() { + c, err := l.Accept() + if err != nil { + t.Error(err) + return + } + defer c.Close() + }() + + if c, err := l.Dial(context.Background(), "tcp", "invalid"); err == nil { + c.Close() + t.Fatalf("dial to invalid address succeeded") + } + c, err := l.Dial(context.Background(), "tcp", "srv.local") + if err != nil { + t.Fatalf("dial failed: %v", err) + return + } + c.Close() +} diff --git a/net/memnet/memnet.go b/net/memnet/memnet.go index 2fc13b4b2..c8799bc17 100644 --- a/net/memnet/memnet.go +++ b/net/memnet/memnet.go @@ -1,8 +1,8 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package memnet implements an in-memory network implementation. -// It is useful for dialing and listening on in-memory addresses -// in tests and other situations where you don't want to use the -// network. -package memnet +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package memnet implements an in-memory network implementation. +// It is useful for dialing and listening on in-memory addresses +// in tests and other situations where you don't want to use the +// network. +package memnet diff --git a/net/memnet/pipe.go b/net/memnet/pipe.go index 51bee1090..471635083 100644 --- a/net/memnet/pipe.go +++ b/net/memnet/pipe.go @@ -1,244 +1,244 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package memnet - -import ( - "bytes" - "context" - "fmt" - "io" - "log" - "net" - "os" - "sync" - "time" -) - -const debugPipe = false - -// Pipe implements an in-memory FIFO with timeouts. -type Pipe struct { - name string - maxBuf int - mu sync.Mutex - cnd *sync.Cond - - blocked bool - closed bool - buf bytes.Buffer - readTimeout time.Time - writeTimeout time.Time - cancelReadTimer func() - cancelWriteTimer func() -} - -// NewPipe creates a Pipe with a buffer size fixed at maxBuf. -func NewPipe(name string, maxBuf int) *Pipe { - p := &Pipe{ - name: name, - maxBuf: maxBuf, - } - p.cnd = sync.NewCond(&p.mu) - return p -} - -// readOrBlock attempts to read from the buffer, if the buffer is empty and -// the connection hasn't been closed it will block until there is a change. -func (p *Pipe) readOrBlock(b []byte) (int, error) { - p.mu.Lock() - defer p.mu.Unlock() - if !p.readTimeout.IsZero() && !time.Now().Before(p.readTimeout) { - return 0, os.ErrDeadlineExceeded - } - if p.blocked { - p.cnd.Wait() - return 0, nil - } - - n, err := p.buf.Read(b) - // err will either be nil or io.EOF. - if err == io.EOF { - if p.closed { - return n, err - } - // Wait for something to change. - p.cnd.Wait() - } - return n, nil -} - -// Read implements io.Reader. -// Once the buffer is drained (i.e. after Close), subsequent calls will -// return io.EOF. -func (p *Pipe) Read(b []byte) (n int, err error) { - if debugPipe { - orig := b - defer func() { - log.Printf("Pipe(%q).Read(%q) n=%d, err=%v", p.name, string(orig[:n]), n, err) - }() - } - for n == 0 { - n2, err := p.readOrBlock(b) - if err != nil { - return n2, err - } - n += n2 - } - p.cnd.Signal() - return n, nil -} - -// writeOrBlock attempts to write to the buffer, if the buffer is full it will -// block until there is a change. -func (p *Pipe) writeOrBlock(b []byte) (int, error) { - p.mu.Lock() - defer p.mu.Unlock() - if p.closed { - return 0, net.ErrClosed - } - if !p.writeTimeout.IsZero() && !time.Now().Before(p.writeTimeout) { - return 0, os.ErrDeadlineExceeded - } - if p.blocked { - p.cnd.Wait() - return 0, nil - } - - // Optimistically we want to write the entire slice. - n := len(b) - if limit := p.maxBuf - p.buf.Len(); limit < n { - // However, we don't have enough capacity to write everything. - n = limit - } - if n == 0 { - // Wait for something to change. - p.cnd.Wait() - return 0, nil - } - - p.buf.Write(b[:n]) - p.cnd.Signal() - return n, nil -} - -// Write implements io.Writer. -func (p *Pipe) Write(b []byte) (n int, err error) { - if debugPipe { - orig := b - defer func() { - log.Printf("Pipe(%q).Write(%q) n=%d, err=%v", p.name, string(orig), n, err) - }() - } - for len(b) > 0 { - n2, err := p.writeOrBlock(b) - if err != nil { - return n + n2, err - } - n += n2 - b = b[n2:] - } - return n, nil -} - -// Close closes the pipe. -func (p *Pipe) Close() error { - p.mu.Lock() - defer p.mu.Unlock() - p.closed = true - p.blocked = false - if p.cancelWriteTimer != nil { - p.cancelWriteTimer() - p.cancelWriteTimer = nil - } - if p.cancelReadTimer != nil { - p.cancelReadTimer() - p.cancelReadTimer = nil - } - p.cnd.Broadcast() - - return nil -} - -func (p *Pipe) deadlineTimer(t time.Time) func() { - if t.IsZero() { - return nil - } - if t.Before(time.Now()) { - p.cnd.Broadcast() - return nil - } - ctx, cancel := context.WithDeadline(context.Background(), t) - go func() { - <-ctx.Done() - if ctx.Err() == context.DeadlineExceeded { - p.cnd.Broadcast() - } - }() - return cancel -} - -// SetReadDeadline sets the deadline for future Read calls. -func (p *Pipe) SetReadDeadline(t time.Time) error { - p.mu.Lock() - defer p.mu.Unlock() - p.readTimeout = t - // If we already have a deadline, cancel it and create a new one. - if p.cancelReadTimer != nil { - p.cancelReadTimer() - p.cancelReadTimer = nil - } - p.cancelReadTimer = p.deadlineTimer(t) - return nil -} - -// SetWriteDeadline sets the deadline for future Write calls. -func (p *Pipe) SetWriteDeadline(t time.Time) error { - p.mu.Lock() - defer p.mu.Unlock() - p.writeTimeout = t - // If we already have a deadline, cancel it and create a new one. - if p.cancelWriteTimer != nil { - p.cancelWriteTimer() - p.cancelWriteTimer = nil - } - p.cancelWriteTimer = p.deadlineTimer(t) - return nil -} - -// Block will cause all calls to Read and Write to block until they either -// timeout, are unblocked or the pipe is closed. -func (p *Pipe) Block() error { - p.mu.Lock() - defer p.mu.Unlock() - closed := p.closed - blocked := p.blocked - p.blocked = true - - if closed { - return fmt.Errorf("memnet.Pipe(%q).Block: closed", p.name) - } - if blocked { - return fmt.Errorf("memnet.Pipe(%q).Block: already blocked", p.name) - } - p.cnd.Broadcast() - return nil -} - -// Unblock will cause all blocked Read/Write calls to continue execution. -func (p *Pipe) Unblock() error { - p.mu.Lock() - defer p.mu.Unlock() - closed := p.closed - blocked := p.blocked - p.blocked = false - - if closed { - return fmt.Errorf("memnet.Pipe(%q).Block: closed", p.name) - } - if !blocked { - return fmt.Errorf("memnet.Pipe(%q).Block: already unblocked", p.name) - } - p.cnd.Broadcast() - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package memnet + +import ( + "bytes" + "context" + "fmt" + "io" + "log" + "net" + "os" + "sync" + "time" +) + +const debugPipe = false + +// Pipe implements an in-memory FIFO with timeouts. +type Pipe struct { + name string + maxBuf int + mu sync.Mutex + cnd *sync.Cond + + blocked bool + closed bool + buf bytes.Buffer + readTimeout time.Time + writeTimeout time.Time + cancelReadTimer func() + cancelWriteTimer func() +} + +// NewPipe creates a Pipe with a buffer size fixed at maxBuf. +func NewPipe(name string, maxBuf int) *Pipe { + p := &Pipe{ + name: name, + maxBuf: maxBuf, + } + p.cnd = sync.NewCond(&p.mu) + return p +} + +// readOrBlock attempts to read from the buffer, if the buffer is empty and +// the connection hasn't been closed it will block until there is a change. +func (p *Pipe) readOrBlock(b []byte) (int, error) { + p.mu.Lock() + defer p.mu.Unlock() + if !p.readTimeout.IsZero() && !time.Now().Before(p.readTimeout) { + return 0, os.ErrDeadlineExceeded + } + if p.blocked { + p.cnd.Wait() + return 0, nil + } + + n, err := p.buf.Read(b) + // err will either be nil or io.EOF. + if err == io.EOF { + if p.closed { + return n, err + } + // Wait for something to change. + p.cnd.Wait() + } + return n, nil +} + +// Read implements io.Reader. +// Once the buffer is drained (i.e. after Close), subsequent calls will +// return io.EOF. +func (p *Pipe) Read(b []byte) (n int, err error) { + if debugPipe { + orig := b + defer func() { + log.Printf("Pipe(%q).Read(%q) n=%d, err=%v", p.name, string(orig[:n]), n, err) + }() + } + for n == 0 { + n2, err := p.readOrBlock(b) + if err != nil { + return n2, err + } + n += n2 + } + p.cnd.Signal() + return n, nil +} + +// writeOrBlock attempts to write to the buffer, if the buffer is full it will +// block until there is a change. +func (p *Pipe) writeOrBlock(b []byte) (int, error) { + p.mu.Lock() + defer p.mu.Unlock() + if p.closed { + return 0, net.ErrClosed + } + if !p.writeTimeout.IsZero() && !time.Now().Before(p.writeTimeout) { + return 0, os.ErrDeadlineExceeded + } + if p.blocked { + p.cnd.Wait() + return 0, nil + } + + // Optimistically we want to write the entire slice. + n := len(b) + if limit := p.maxBuf - p.buf.Len(); limit < n { + // However, we don't have enough capacity to write everything. + n = limit + } + if n == 0 { + // Wait for something to change. + p.cnd.Wait() + return 0, nil + } + + p.buf.Write(b[:n]) + p.cnd.Signal() + return n, nil +} + +// Write implements io.Writer. +func (p *Pipe) Write(b []byte) (n int, err error) { + if debugPipe { + orig := b + defer func() { + log.Printf("Pipe(%q).Write(%q) n=%d, err=%v", p.name, string(orig), n, err) + }() + } + for len(b) > 0 { + n2, err := p.writeOrBlock(b) + if err != nil { + return n + n2, err + } + n += n2 + b = b[n2:] + } + return n, nil +} + +// Close closes the pipe. +func (p *Pipe) Close() error { + p.mu.Lock() + defer p.mu.Unlock() + p.closed = true + p.blocked = false + if p.cancelWriteTimer != nil { + p.cancelWriteTimer() + p.cancelWriteTimer = nil + } + if p.cancelReadTimer != nil { + p.cancelReadTimer() + p.cancelReadTimer = nil + } + p.cnd.Broadcast() + + return nil +} + +func (p *Pipe) deadlineTimer(t time.Time) func() { + if t.IsZero() { + return nil + } + if t.Before(time.Now()) { + p.cnd.Broadcast() + return nil + } + ctx, cancel := context.WithDeadline(context.Background(), t) + go func() { + <-ctx.Done() + if ctx.Err() == context.DeadlineExceeded { + p.cnd.Broadcast() + } + }() + return cancel +} + +// SetReadDeadline sets the deadline for future Read calls. +func (p *Pipe) SetReadDeadline(t time.Time) error { + p.mu.Lock() + defer p.mu.Unlock() + p.readTimeout = t + // If we already have a deadline, cancel it and create a new one. + if p.cancelReadTimer != nil { + p.cancelReadTimer() + p.cancelReadTimer = nil + } + p.cancelReadTimer = p.deadlineTimer(t) + return nil +} + +// SetWriteDeadline sets the deadline for future Write calls. +func (p *Pipe) SetWriteDeadline(t time.Time) error { + p.mu.Lock() + defer p.mu.Unlock() + p.writeTimeout = t + // If we already have a deadline, cancel it and create a new one. + if p.cancelWriteTimer != nil { + p.cancelWriteTimer() + p.cancelWriteTimer = nil + } + p.cancelWriteTimer = p.deadlineTimer(t) + return nil +} + +// Block will cause all calls to Read and Write to block until they either +// timeout, are unblocked or the pipe is closed. +func (p *Pipe) Block() error { + p.mu.Lock() + defer p.mu.Unlock() + closed := p.closed + blocked := p.blocked + p.blocked = true + + if closed { + return fmt.Errorf("memnet.Pipe(%q).Block: closed", p.name) + } + if blocked { + return fmt.Errorf("memnet.Pipe(%q).Block: already blocked", p.name) + } + p.cnd.Broadcast() + return nil +} + +// Unblock will cause all blocked Read/Write calls to continue execution. +func (p *Pipe) Unblock() error { + p.mu.Lock() + defer p.mu.Unlock() + closed := p.closed + blocked := p.blocked + p.blocked = false + + if closed { + return fmt.Errorf("memnet.Pipe(%q).Block: closed", p.name) + } + if !blocked { + return fmt.Errorf("memnet.Pipe(%q).Block: already unblocked", p.name) + } + p.cnd.Broadcast() + return nil +} diff --git a/net/memnet/pipe_test.go b/net/memnet/pipe_test.go index b3775cf7f..a86d65388 100644 --- a/net/memnet/pipe_test.go +++ b/net/memnet/pipe_test.go @@ -1,117 +1,117 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package memnet - -import ( - "errors" - "fmt" - "os" - "testing" - "time" -) - -func TestPipeHello(t *testing.T) { - p := NewPipe("p1", 1<<16) - msg := "Hello, World!" - if n, err := p.Write([]byte(msg)); err != nil { - t.Fatal(err) - } else if n != len(msg) { - t.Errorf("p.Write(%q) n=%d, want %d", msg, n, len(msg)) - } - b := make([]byte, len(msg)) - if n, err := p.Read(b); err != nil { - t.Fatal(err) - } else if n != len(b) { - t.Errorf("p.Read(%q) n=%d, want %d", string(b[:n]), n, len(b)) - } - if got := string(b); got != msg { - t.Errorf("p.Read: %q, want %q", got, msg) - } -} - -func TestPipeTimeout(t *testing.T) { - t.Run("write", func(t *testing.T) { - p := NewPipe("p1", 1<<16) - p.SetWriteDeadline(time.Now().Add(-1 * time.Second)) - n, err := p.Write([]byte{'h'}) - if !errors.Is(err, os.ErrDeadlineExceeded) { - t.Errorf("missing write timeout got err: %v", err) - } - if n != 0 { - t.Errorf("n=%d on timeout", n) - } - }) - t.Run("read", func(t *testing.T) { - p := NewPipe("p1", 1<<16) - p.Write([]byte{'h'}) - - p.SetReadDeadline(time.Now().Add(-1 * time.Second)) - b := make([]byte, 1) - n, err := p.Read(b) - if !errors.Is(err, os.ErrDeadlineExceeded) { - t.Errorf("missing read timeout got err: %v", err) - } - if n != 0 { - t.Errorf("n=%d on timeout", n) - } - }) - t.Run("block-write", func(t *testing.T) { - p := NewPipe("p1", 1<<16) - p.SetWriteDeadline(time.Now().Add(10 * time.Millisecond)) - if err := p.Block(); err != nil { - t.Fatal(err) - } - if _, err := p.Write([]byte{'h'}); !errors.Is(err, os.ErrDeadlineExceeded) { - t.Fatalf("want write timeout got: %v", err) - } - }) - t.Run("block-read", func(t *testing.T) { - p := NewPipe("p1", 1<<16) - p.Write([]byte{'h', 'i'}) - p.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) - b := make([]byte, 1) - if err := p.Block(); err != nil { - t.Fatal(err) - } - if _, err := p.Read(b); !errors.Is(err, os.ErrDeadlineExceeded) { - t.Fatalf("want read timeout got: %v", err) - } - }) -} - -func TestLimit(t *testing.T) { - p := NewPipe("p1", 1) - errCh := make(chan error) - go func() { - n, err := p.Write([]byte{'a', 'b', 'c'}) - if err != nil { - errCh <- err - } else if n != 3 { - errCh <- fmt.Errorf("p.Write n=%d, want 3", n) - } else { - errCh <- nil - } - }() - b := make([]byte, 3) - - if n, err := p.Read(b); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Errorf("Read(%q): n=%d want 1", string(b), n) - } - if n, err := p.Read(b); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Errorf("Read(%q): n=%d want 1", string(b), n) - } - if n, err := p.Read(b); err != nil { - t.Fatal(err) - } else if n != 1 { - t.Errorf("Read(%q): n=%d want 1", string(b), n) - } - - if err := <-errCh; err != nil { - t.Error(err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package memnet + +import ( + "errors" + "fmt" + "os" + "testing" + "time" +) + +func TestPipeHello(t *testing.T) { + p := NewPipe("p1", 1<<16) + msg := "Hello, World!" + if n, err := p.Write([]byte(msg)); err != nil { + t.Fatal(err) + } else if n != len(msg) { + t.Errorf("p.Write(%q) n=%d, want %d", msg, n, len(msg)) + } + b := make([]byte, len(msg)) + if n, err := p.Read(b); err != nil { + t.Fatal(err) + } else if n != len(b) { + t.Errorf("p.Read(%q) n=%d, want %d", string(b[:n]), n, len(b)) + } + if got := string(b); got != msg { + t.Errorf("p.Read: %q, want %q", got, msg) + } +} + +func TestPipeTimeout(t *testing.T) { + t.Run("write", func(t *testing.T) { + p := NewPipe("p1", 1<<16) + p.SetWriteDeadline(time.Now().Add(-1 * time.Second)) + n, err := p.Write([]byte{'h'}) + if !errors.Is(err, os.ErrDeadlineExceeded) { + t.Errorf("missing write timeout got err: %v", err) + } + if n != 0 { + t.Errorf("n=%d on timeout", n) + } + }) + t.Run("read", func(t *testing.T) { + p := NewPipe("p1", 1<<16) + p.Write([]byte{'h'}) + + p.SetReadDeadline(time.Now().Add(-1 * time.Second)) + b := make([]byte, 1) + n, err := p.Read(b) + if !errors.Is(err, os.ErrDeadlineExceeded) { + t.Errorf("missing read timeout got err: %v", err) + } + if n != 0 { + t.Errorf("n=%d on timeout", n) + } + }) + t.Run("block-write", func(t *testing.T) { + p := NewPipe("p1", 1<<16) + p.SetWriteDeadline(time.Now().Add(10 * time.Millisecond)) + if err := p.Block(); err != nil { + t.Fatal(err) + } + if _, err := p.Write([]byte{'h'}); !errors.Is(err, os.ErrDeadlineExceeded) { + t.Fatalf("want write timeout got: %v", err) + } + }) + t.Run("block-read", func(t *testing.T) { + p := NewPipe("p1", 1<<16) + p.Write([]byte{'h', 'i'}) + p.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) + b := make([]byte, 1) + if err := p.Block(); err != nil { + t.Fatal(err) + } + if _, err := p.Read(b); !errors.Is(err, os.ErrDeadlineExceeded) { + t.Fatalf("want read timeout got: %v", err) + } + }) +} + +func TestLimit(t *testing.T) { + p := NewPipe("p1", 1) + errCh := make(chan error) + go func() { + n, err := p.Write([]byte{'a', 'b', 'c'}) + if err != nil { + errCh <- err + } else if n != 3 { + errCh <- fmt.Errorf("p.Write n=%d, want 3", n) + } else { + errCh <- nil + } + }() + b := make([]byte, 3) + + if n, err := p.Read(b); err != nil { + t.Fatal(err) + } else if n != 1 { + t.Errorf("Read(%q): n=%d want 1", string(b), n) + } + if n, err := p.Read(b); err != nil { + t.Fatal(err) + } else if n != 1 { + t.Errorf("Read(%q): n=%d want 1", string(b), n) + } + if n, err := p.Read(b); err != nil { + t.Fatal(err) + } else if n != 1 { + t.Errorf("Read(%q): n=%d want 1", string(b), n) + } + + if err := <-errCh; err != nil { + t.Error(err) + } +} diff --git a/net/netaddr/netaddr.go b/net/netaddr/netaddr.go index 6f85a52b7..1ab6c053a 100644 --- a/net/netaddr/netaddr.go +++ b/net/netaddr/netaddr.go @@ -1,49 +1,49 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package netaddr is a transitional package while we finish migrating from inet.af/netaddr -// to Go 1.18's net/netip. -// -// TODO(bradfitz): delete this package eventually. Tracking bug is -// https://github.com/tailscale/tailscale/issues/5162 -package netaddr - -import ( - "net" - "net/netip" -) - -// IPv4 returns the IP of the IPv4 address a.b.c.d. -func IPv4(a, b, c, d uint8) netip.Addr { - return netip.AddrFrom4([4]byte{a, b, c, d}) -} - -// Unmap returns the provided AddrPort with its Addr IP component Unmap'ed. -// -// See https://github.com/golang/go/issues/53607#issuecomment-1203466984 -func Unmap(ap netip.AddrPort) netip.AddrPort { - return netip.AddrPortFrom(ap.Addr().Unmap(), ap.Port()) -} - -// FromStdIPNet returns an IPPrefix from the standard library's IPNet type. -// If std is invalid, ok is false. -func FromStdIPNet(std *net.IPNet) (prefix netip.Prefix, ok bool) { - ip, ok := netip.AddrFromSlice(std.IP) - if !ok { - return netip.Prefix{}, false - } - ip = ip.Unmap() - - if l := len(std.Mask); l != net.IPv4len && l != net.IPv6len { - // Invalid mask. - return netip.Prefix{}, false - } - - ones, bits := std.Mask.Size() - if ones == 0 && bits == 0 { - // IPPrefix does not support non-contiguous masks. - return netip.Prefix{}, false - } - - return netip.PrefixFrom(ip, ones), true -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package netaddr is a transitional package while we finish migrating from inet.af/netaddr +// to Go 1.18's net/netip. +// +// TODO(bradfitz): delete this package eventually. Tracking bug is +// https://github.com/tailscale/tailscale/issues/5162 +package netaddr + +import ( + "net" + "net/netip" +) + +// IPv4 returns the IP of the IPv4 address a.b.c.d. +func IPv4(a, b, c, d uint8) netip.Addr { + return netip.AddrFrom4([4]byte{a, b, c, d}) +} + +// Unmap returns the provided AddrPort with its Addr IP component Unmap'ed. +// +// See https://github.com/golang/go/issues/53607#issuecomment-1203466984 +func Unmap(ap netip.AddrPort) netip.AddrPort { + return netip.AddrPortFrom(ap.Addr().Unmap(), ap.Port()) +} + +// FromStdIPNet returns an IPPrefix from the standard library's IPNet type. +// If std is invalid, ok is false. +func FromStdIPNet(std *net.IPNet) (prefix netip.Prefix, ok bool) { + ip, ok := netip.AddrFromSlice(std.IP) + if !ok { + return netip.Prefix{}, false + } + ip = ip.Unmap() + + if l := len(std.Mask); l != net.IPv4len && l != net.IPv6len { + // Invalid mask. + return netip.Prefix{}, false + } + + ones, bits := std.Mask.Size() + if ones == 0 && bits == 0 { + // IPPrefix does not support non-contiguous masks. + return netip.Prefix{}, false + } + + return netip.PrefixFrom(ip, ones), true +} diff --git a/net/neterror/neterror.go b/net/neterror/neterror.go index f570b8930..e2387440d 100644 --- a/net/neterror/neterror.go +++ b/net/neterror/neterror.go @@ -1,82 +1,82 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package neterror classifies network errors. -package neterror - -import ( - "errors" - "fmt" - "runtime" - "syscall" -) - -var errEPERM error = syscall.EPERM // box it into interface just once - -// TreatAsLostUDP reports whether err is an error from a UDP send -// operation that should be treated as a UDP packet that just got -// lost. -// -// Notably, on Linux this reports true for EPERM errors (from outbound -// firewall blocks) which aren't really send errors; they're just -// sends that are never going to make it because the local OS blocked -// it. -func TreatAsLostUDP(err error) bool { - if err == nil { - return false - } - switch runtime.GOOS { - case "linux": - // Linux, while not documented in the man page, - // returns EPERM when there's an OUTPUT rule with -j - // DROP or -j REJECT. We use this very specific - // Linux+EPERM check rather than something super broad - // like net.Error.Temporary which could be anything. - // - // For now we only do this on Linux, as such outgoing - // firewall violations mapping to syscall errors - // hasn't yet been observed on other OSes. - return errors.Is(err, errEPERM) - } - return false -} - -var packetWasTruncated func(error) bool // non-nil on Windows at least - -// PacketWasTruncated reports whether err indicates truncation but the RecvFrom -// that generated err was otherwise successful. On Windows, Go's UDP RecvFrom -// calls WSARecvFrom which returns the WSAEMSGSIZE error code when the received -// datagram is larger than the provided buffer. When that happens, both a valid -// size and an error are returned (as per the partial fix for golang/go#14074). -// If the WSAEMSGSIZE error is returned, then we ignore the error to get -// semantics similar to the POSIX operating systems. One caveat is that it -// appears that the source address is not returned when WSAEMSGSIZE occurs, but -// we do not currently look at the source address. -func PacketWasTruncated(err error) bool { - if packetWasTruncated == nil { - return false - } - return packetWasTruncated(err) -} - -var shouldDisableUDPGSO func(error) bool // non-nil on Linux - -func ShouldDisableUDPGSO(err error) bool { - if shouldDisableUDPGSO == nil { - return false - } - return shouldDisableUDPGSO(err) -} - -type ErrUDPGSODisabled struct { - OnLaddr string - RetryErr error -} - -func (e ErrUDPGSODisabled) Error() string { - return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.OnLaddr) -} - -func (e ErrUDPGSODisabled) Unwrap() error { - return e.RetryErr -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package neterror classifies network errors. +package neterror + +import ( + "errors" + "fmt" + "runtime" + "syscall" +) + +var errEPERM error = syscall.EPERM // box it into interface just once + +// TreatAsLostUDP reports whether err is an error from a UDP send +// operation that should be treated as a UDP packet that just got +// lost. +// +// Notably, on Linux this reports true for EPERM errors (from outbound +// firewall blocks) which aren't really send errors; they're just +// sends that are never going to make it because the local OS blocked +// it. +func TreatAsLostUDP(err error) bool { + if err == nil { + return false + } + switch runtime.GOOS { + case "linux": + // Linux, while not documented in the man page, + // returns EPERM when there's an OUTPUT rule with -j + // DROP or -j REJECT. We use this very specific + // Linux+EPERM check rather than something super broad + // like net.Error.Temporary which could be anything. + // + // For now we only do this on Linux, as such outgoing + // firewall violations mapping to syscall errors + // hasn't yet been observed on other OSes. + return errors.Is(err, errEPERM) + } + return false +} + +var packetWasTruncated func(error) bool // non-nil on Windows at least + +// PacketWasTruncated reports whether err indicates truncation but the RecvFrom +// that generated err was otherwise successful. On Windows, Go's UDP RecvFrom +// calls WSARecvFrom which returns the WSAEMSGSIZE error code when the received +// datagram is larger than the provided buffer. When that happens, both a valid +// size and an error are returned (as per the partial fix for golang/go#14074). +// If the WSAEMSGSIZE error is returned, then we ignore the error to get +// semantics similar to the POSIX operating systems. One caveat is that it +// appears that the source address is not returned when WSAEMSGSIZE occurs, but +// we do not currently look at the source address. +func PacketWasTruncated(err error) bool { + if packetWasTruncated == nil { + return false + } + return packetWasTruncated(err) +} + +var shouldDisableUDPGSO func(error) bool // non-nil on Linux + +func ShouldDisableUDPGSO(err error) bool { + if shouldDisableUDPGSO == nil { + return false + } + return shouldDisableUDPGSO(err) +} + +type ErrUDPGSODisabled struct { + OnLaddr string + RetryErr error +} + +func (e ErrUDPGSODisabled) Error() string { + return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.OnLaddr) +} + +func (e ErrUDPGSODisabled) Unwrap() error { + return e.RetryErr +} diff --git a/net/neterror/neterror_linux.go b/net/neterror/neterror_linux.go index 3f402dd30..857367fe8 100644 --- a/net/neterror/neterror_linux.go +++ b/net/neterror/neterror_linux.go @@ -1,26 +1,26 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package neterror - -import ( - "errors" - "os" - - "golang.org/x/sys/unix" -) - -func init() { - shouldDisableUDPGSO = func(err error) bool { - var serr *os.SyscallError - if errors.As(err, &serr) { - // EIO is returned by udp_send_skb() if the device driver does not - // have tx checksumming enabled, which is a hard requirement of - // UDP_SEGMENT. See: - // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228 - // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942 - return serr.Err == unix.EIO - } - return false - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package neterror + +import ( + "errors" + "os" + + "golang.org/x/sys/unix" +) + +func init() { + shouldDisableUDPGSO = func(err error) bool { + var serr *os.SyscallError + if errors.As(err, &serr) { + // EIO is returned by udp_send_skb() if the device driver does not + // have tx checksumming enabled, which is a hard requirement of + // UDP_SEGMENT. See: + // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228 + // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942 + return serr.Err == unix.EIO + } + return false + } +} diff --git a/net/neterror/neterror_linux_test.go b/net/neterror/neterror_linux_test.go index 1d600d6b6..5b9906074 100644 --- a/net/neterror/neterror_linux_test.go +++ b/net/neterror/neterror_linux_test.go @@ -1,54 +1,54 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package neterror - -import ( - "errors" - "net" - "os" - "syscall" - "testing" -) - -func TestTreatAsLostUDP(t *testing.T) { - tests := []struct { - name string - err error - want bool - }{ - {"nil", nil, false}, - {"non-nil", errors.New("foo"), false}, - {"eperm", syscall.EPERM, true}, - { - name: "operror", - err: &net.OpError{ - Op: "write", - Err: &os.SyscallError{ - Syscall: "sendto", - Err: syscall.EPERM, - }, - }, - want: true, - }, - { - name: "host_unreach", - err: &net.OpError{ - Op: "write", - Err: &os.SyscallError{ - Syscall: "sendto", - Err: syscall.EHOSTUNREACH, - }, - }, - want: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := TreatAsLostUDP(tt.err); got != tt.want { - t.Errorf("got = %v; want %v", got, tt.want) - } - }) - } - -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package neterror + +import ( + "errors" + "net" + "os" + "syscall" + "testing" +) + +func TestTreatAsLostUDP(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + {"nil", nil, false}, + {"non-nil", errors.New("foo"), false}, + {"eperm", syscall.EPERM, true}, + { + name: "operror", + err: &net.OpError{ + Op: "write", + Err: &os.SyscallError{ + Syscall: "sendto", + Err: syscall.EPERM, + }, + }, + want: true, + }, + { + name: "host_unreach", + err: &net.OpError{ + Op: "write", + Err: &os.SyscallError{ + Syscall: "sendto", + Err: syscall.EHOSTUNREACH, + }, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := TreatAsLostUDP(tt.err); got != tt.want { + t.Errorf("got = %v; want %v", got, tt.want) + } + }) + } + +} diff --git a/net/neterror/neterror_windows.go b/net/neterror/neterror_windows.go index c293ec4a9..bf112f5ed 100644 --- a/net/neterror/neterror_windows.go +++ b/net/neterror/neterror_windows.go @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package neterror - -import ( - "errors" - - "golang.org/x/sys/windows" -) - -func init() { - packetWasTruncated = func(err error) bool { - return errors.Is(err, windows.WSAEMSGSIZE) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package neterror + +import ( + "errors" + + "golang.org/x/sys/windows" +) + +func init() { + packetWasTruncated = func(err error) bool { + return errors.Is(err, windows.WSAEMSGSIZE) + } +} diff --git a/net/netkernelconf/netkernelconf.go b/net/netkernelconf/netkernelconf.go index 23ec9c5b6..3ea502b37 100644 --- a/net/netkernelconf/netkernelconf.go +++ b/net/netkernelconf/netkernelconf.go @@ -1,5 +1,5 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package netkernelconf contains code for checking kernel netdev config. -package netkernelconf +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package netkernelconf contains code for checking kernel netdev config. +package netkernelconf diff --git a/net/netknob/netknob.go b/net/netknob/netknob.go index 0b271fc95..53171f424 100644 --- a/net/netknob/netknob.go +++ b/net/netknob/netknob.go @@ -1,29 +1,29 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package netknob has Tailscale network knobs. -package netknob - -import ( - "runtime" - "time" -) - -// PlatformTCPKeepAlive returns the default net.Dialer.KeepAlive -// value for the current runtime.GOOS. -func PlatformTCPKeepAlive() time.Duration { - switch runtime.GOOS { - case "ios", "android": - // Disable TCP keep-alives on mobile platforms. - // See https://github.com/golang/go/issues/48622. - // - // TODO(bradfitz): in 1.17.x, try disabling TCP - // keep-alives on for all platforms. - return -1 - } - - // Otherwise, default to 30 seconds, which is mostly what we - // used to do. In some places we used the zero value, which Go - // defaults to 15 seconds. But 30 seconds is fine. - return 30 * time.Second -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package netknob has Tailscale network knobs. +package netknob + +import ( + "runtime" + "time" +) + +// PlatformTCPKeepAlive returns the default net.Dialer.KeepAlive +// value for the current runtime.GOOS. +func PlatformTCPKeepAlive() time.Duration { + switch runtime.GOOS { + case "ios", "android": + // Disable TCP keep-alives on mobile platforms. + // See https://github.com/golang/go/issues/48622. + // + // TODO(bradfitz): in 1.17.x, try disabling TCP + // keep-alives on for all platforms. + return -1 + } + + // Otherwise, default to 30 seconds, which is mostly what we + // used to do. In some places we used the zero value, which Go + // defaults to 15 seconds. But 30 seconds is fine. + return 30 * time.Second +} diff --git a/net/netmon/netmon_darwin_test.go b/net/netmon/netmon_darwin_test.go index 77a212683..84c67cf6f 100644 --- a/net/netmon/netmon_darwin_test.go +++ b/net/netmon/netmon_darwin_test.go @@ -1,27 +1,27 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package netmon - -import ( - "encoding/hex" - "strings" - "testing" - - "golang.org/x/net/route" -) - -func TestIssue1416RIB(t *testing.T) { - const ribHex = `32 00 05 10 30 00 00 00 00 00 00 00 04 00 00 00 14 12 04 00 06 03 06 00 65 6e 30 ac 87 a3 19 7f 82 00 00 00 0e 12 00 00 00 00 06 00 91 e0 f0 01 00 00` - rtmMsg, err := hex.DecodeString(strings.ReplaceAll(ribHex, " ", "")) - if err != nil { - t.Fatal(err) - } - msgs, err := route.ParseRIB(route.RIBTypeRoute, rtmMsg) - if err != nil { - t.Logf("ParseRIB: %v", err) - t.Skip("skipping on known failure; see https://github.com/tailscale/tailscale/issues/1416") - t.Fatal(err) - } - t.Logf("Got: %#v", msgs) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package netmon + +import ( + "encoding/hex" + "strings" + "testing" + + "golang.org/x/net/route" +) + +func TestIssue1416RIB(t *testing.T) { + const ribHex = `32 00 05 10 30 00 00 00 00 00 00 00 04 00 00 00 14 12 04 00 06 03 06 00 65 6e 30 ac 87 a3 19 7f 82 00 00 00 0e 12 00 00 00 00 06 00 91 e0 f0 01 00 00` + rtmMsg, err := hex.DecodeString(strings.ReplaceAll(ribHex, " ", "")) + if err != nil { + t.Fatal(err) + } + msgs, err := route.ParseRIB(route.RIBTypeRoute, rtmMsg) + if err != nil { + t.Logf("ParseRIB: %v", err) + t.Skip("skipping on known failure; see https://github.com/tailscale/tailscale/issues/1416") + t.Fatal(err) + } + t.Logf("Got: %#v", msgs) +} diff --git a/net/netmon/netmon_freebsd.go b/net/netmon/netmon_freebsd.go index 724f964c9..30480a1d3 100644 --- a/net/netmon/netmon_freebsd.go +++ b/net/netmon/netmon_freebsd.go @@ -1,56 +1,56 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package netmon - -import ( - "bufio" - "fmt" - "net" - "strings" - - "tailscale.com/types/logger" -) - -// unspecifiedMessage is a minimal message implementation that should not -// be ignored. In general, OS-specific implementations should use better -// types and avoid this if they can. -type unspecifiedMessage struct{} - -func (unspecifiedMessage) ignore() bool { return false } - -// devdConn implements osMon using devd(8). -type devdConn struct { - conn net.Conn -} - -func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) { - conn, err := net.Dial("unixpacket", "/var/run/devd.seqpacket.pipe") - if err != nil { - logf("devd dial error: %v, falling back to polling method", err) - return newPollingMon(logf, m) - } - return &devdConn{conn}, nil -} - -func (c *devdConn) IsInterestingInterface(iface string) bool { return true } - -func (c *devdConn) Close() error { - return c.conn.Close() -} - -func (c *devdConn) Receive() (message, error) { - for { - msg, err := bufio.NewReader(c.conn).ReadString('\n') - if err != nil { - return nil, fmt.Errorf("reading devd socket: %v", err) - } - // Only return messages related to the network subsystem. - if !strings.Contains(msg, "system=IFNET") { - continue - } - // TODO: this is where the devd-specific message would - // get converted into a "standard" event message and returned. - return unspecifiedMessage{}, nil - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package netmon + +import ( + "bufio" + "fmt" + "net" + "strings" + + "tailscale.com/types/logger" +) + +// unspecifiedMessage is a minimal message implementation that should not +// be ignored. In general, OS-specific implementations should use better +// types and avoid this if they can. +type unspecifiedMessage struct{} + +func (unspecifiedMessage) ignore() bool { return false } + +// devdConn implements osMon using devd(8). +type devdConn struct { + conn net.Conn +} + +func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) { + conn, err := net.Dial("unixpacket", "/var/run/devd.seqpacket.pipe") + if err != nil { + logf("devd dial error: %v, falling back to polling method", err) + return newPollingMon(logf, m) + } + return &devdConn{conn}, nil +} + +func (c *devdConn) IsInterestingInterface(iface string) bool { return true } + +func (c *devdConn) Close() error { + return c.conn.Close() +} + +func (c *devdConn) Receive() (message, error) { + for { + msg, err := bufio.NewReader(c.conn).ReadString('\n') + if err != nil { + return nil, fmt.Errorf("reading devd socket: %v", err) + } + // Only return messages related to the network subsystem. + if !strings.Contains(msg, "system=IFNET") { + continue + } + // TODO: this is where the devd-specific message would + // get converted into a "standard" event message and returned. + return unspecifiedMessage{}, nil + } +} diff --git a/net/netmon/netmon_linux.go b/net/netmon/netmon_linux.go index 888afa92d..dd23dd342 100644 --- a/net/netmon/netmon_linux.go +++ b/net/netmon/netmon_linux.go @@ -1,290 +1,290 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !android - -package netmon - -import ( - "net" - "net/netip" - "time" - - "github.com/jsimonetti/rtnetlink" - "github.com/mdlayher/netlink" - "golang.org/x/sys/unix" - "tailscale.com/envknob" - "tailscale.com/net/tsaddr" - "tailscale.com/types/logger" -) - -var debugNetlinkMessages = envknob.RegisterBool("TS_DEBUG_NETLINK") - -// unspecifiedMessage is a minimal message implementation that should not -// be ignored. In general, OS-specific implementations should use better -// types and avoid this if they can. -type unspecifiedMessage struct{} - -func (unspecifiedMessage) ignore() bool { return false } - -// nlConn wraps a *netlink.Conn and returns a monitor.Message -// instead of a netlink.Message. Currently, messages are discarded, -// but down the line, when messages trigger different logic depending -// on the type of event, this provides the capability of handling -// each architecture-specific message in a generic fashion. -type nlConn struct { - logf logger.Logf - conn *netlink.Conn - buffered []netlink.Message - - // addrCache maps interface indices to a set of addresses, and is - // used to suppress duplicate RTM_NEWADDR messages. It is populated - // by RTM_NEWADDR messages and de-populated by RTM_DELADDR. See - // issue #4282. - addrCache map[uint32]map[netip.Addr]bool -} - -func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) { - conn, err := netlink.Dial(unix.NETLINK_ROUTE, &netlink.Config{ - // Routes get us most of the events of interest, but we need - // address as well to cover things like DHCP deciding to give - // us a new address upon renewal - routing wouldn't change, - // but all reachability would. - Groups: unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR | - unix.RTMGRP_IPV4_ROUTE | unix.RTMGRP_IPV6_ROUTE | - unix.RTMGRP_IPV4_RULE, // no IPV6_RULE in x/sys/unix - }) - if err != nil { - // Google Cloud Run does not implement NETLINK_ROUTE RTMGRP support - logf("monitor_linux: AF_NETLINK RTMGRP failed, falling back to polling") - return newPollingMon(logf, m) - } - return &nlConn{logf: logf, conn: conn, addrCache: make(map[uint32]map[netip.Addr]bool)}, nil -} - -func (c *nlConn) IsInterestingInterface(iface string) bool { return true } - -func (c *nlConn) Close() error { return c.conn.Close() } - -func (c *nlConn) Receive() (message, error) { - if len(c.buffered) == 0 { - var err error - c.buffered, err = c.conn.Receive() - if err != nil { - return nil, err - } - if len(c.buffered) == 0 { - // Unexpected. Not seen in wild, but sleep defensively. - time.Sleep(time.Second) - return ignoreMessage{}, nil - } - } - msg := c.buffered[0] - c.buffered = c.buffered[1:] - - // See https://github.com/torvalds/linux/blob/master/include/uapi/linux/rtnetlink.h - // And https://man7.org/linux/man-pages/man7/rtnetlink.7.html - switch msg.Header.Type { - case unix.RTM_NEWADDR, unix.RTM_DELADDR: - var rmsg rtnetlink.AddressMessage - if err := rmsg.UnmarshalBinary(msg.Data); err != nil { - c.logf("failed to parse type %v: %v", msg.Header.Type, err) - return unspecifiedMessage{}, nil - } - - nip := netaddrIP(rmsg.Attributes.Address) - - if debugNetlinkMessages() { - typ := "RTM_NEWADDR" - if msg.Header.Type == unix.RTM_DELADDR { - typ = "RTM_DELADDR" - } - - // label attributes are seemingly only populated for IPv4 addresses in the wild. - label := rmsg.Attributes.Label - if label == "" { - itf, err := net.InterfaceByIndex(int(rmsg.Index)) - if err == nil { - label = itf.Name - } - } - - c.logf("%s: %s(%d) %s / %s", typ, label, rmsg.Index, rmsg.Attributes.Address, rmsg.Attributes.Local) - } - - addrs := c.addrCache[rmsg.Index] - - // Ignore duplicate RTM_NEWADDR messages using c.addrCache to - // detect them. See nlConn.addrcache and issue #4282. - if msg.Header.Type == unix.RTM_NEWADDR { - if addrs == nil { - addrs = make(map[netip.Addr]bool) - c.addrCache[rmsg.Index] = addrs - } - - if addrs[nip] { - if debugNetlinkMessages() { - c.logf("ignored duplicate RTM_NEWADDR for %s", nip) - } - return ignoreMessage{}, nil - } - - addrs[nip] = true - } else { // msg.Header.Type == unix.RTM_DELADDR - if addrs != nil { - delete(addrs, nip) - } - - if len(addrs) == 0 { - delete(c.addrCache, rmsg.Index) - } - } - - nam := &newAddrMessage{ - IfIndex: rmsg.Index, - Addr: nip, - Delete: msg.Header.Type == unix.RTM_DELADDR, - } - if debugNetlinkMessages() { - c.logf("%+v", nam) - } - return nam, nil - case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: - typeStr := "RTM_NEWROUTE" - if msg.Header.Type == unix.RTM_DELROUTE { - typeStr = "RTM_DELROUTE" - } - var rmsg rtnetlink.RouteMessage - if err := rmsg.UnmarshalBinary(msg.Data); err != nil { - c.logf("%s: failed to parse: %v", typeStr, err) - return unspecifiedMessage{}, nil - } - src := netaddrIPPrefix(rmsg.Attributes.Src, rmsg.SrcLength) - dst := netaddrIPPrefix(rmsg.Attributes.Dst, rmsg.DstLength) - gw := netaddrIP(rmsg.Attributes.Gateway) - - if msg.Header.Type == unix.RTM_NEWROUTE && - (rmsg.Attributes.Table == 255 || rmsg.Attributes.Table == 254) && - (dst.Addr().IsMulticast() || dst.Addr().IsLinkLocalUnicast()) { - - if debugNetlinkMessages() { - c.logf("%s ignored", typeStr) - } - - // Normal Linux route changes on new interface coming up; don't log or react. - return ignoreMessage{}, nil - } - - if rmsg.Table == tsTable && dst.IsSingleIP() { - // Don't log. Spammy and normal to see a bunch of these on start-up, - // which we make ourselves. - } else if tsaddr.IsTailscaleIP(dst.Addr()) { - // Verbose only. - c.logf("%s: [v1] src=%v, dst=%v, gw=%v, outif=%v, table=%v", typeStr, - condNetAddrPrefix(src), condNetAddrPrefix(dst), condNetAddrIP(gw), - rmsg.Attributes.OutIface, rmsg.Attributes.Table) - } else { - c.logf("%s: src=%v, dst=%v, gw=%v, outif=%v, table=%v", typeStr, - condNetAddrPrefix(src), condNetAddrPrefix(dst), condNetAddrIP(gw), - rmsg.Attributes.OutIface, rmsg.Attributes.Table) - } - if msg.Header.Type == unix.RTM_DELROUTE { - // Just logging it for now. - // (Debugging https://github.com/tailscale/tailscale/issues/643) - return unspecifiedMessage{}, nil - } - - nrm := &newRouteMessage{ - Table: rmsg.Table, - Src: src, - Dst: dst, - Gateway: gw, - } - if debugNetlinkMessages() { - c.logf("%+v", nrm) - } - return nrm, nil - case unix.RTM_NEWRULE: - // Probably ourselves adding it. - return ignoreMessage{}, nil - case unix.RTM_DELRULE: - // For https://github.com/tailscale/tailscale/issues/1591 where - // systemd-networkd deletes our rules. - var rmsg rtnetlink.RouteMessage - err := rmsg.UnmarshalBinary(msg.Data) - if err != nil { - c.logf("ip rule deleted; failed to parse netlink message: %v", err) - } else { - c.logf("ip rule deleted: %+v", rmsg) - // On `ip -4 rule del pref 5210 table main`, logs: - // monitor: ip rule deleted: {Family:2 DstLength:0 SrcLength:0 Tos:0 Table:254 Protocol:0 Scope:0 Type:1 Flags:0 Attributes:{Dst: Src: Gateway: OutIface:0 Priority:5210 Table:254 Mark:4294967295 Expires: Metrics: Multipath:[]}} - } - rdm := ipRuleDeletedMessage{ - table: rmsg.Table, - priority: rmsg.Attributes.Priority, - } - if debugNetlinkMessages() { - c.logf("%+v", rdm) - } - return rdm, nil - case unix.RTM_NEWLINK, unix.RTM_DELLINK: - // This is an unhandled message, but don't print an error. - // See https://github.com/tailscale/tailscale/issues/6806 - return unspecifiedMessage{}, nil - default: - c.logf("unhandled netlink msg type %+v, %q", msg.Header, msg.Data) - return unspecifiedMessage{}, nil - } -} - -func netaddrIP(std net.IP) netip.Addr { - ip, _ := netip.AddrFromSlice(std) - return ip.Unmap() -} - -func netaddrIPPrefix(std net.IP, bits uint8) netip.Prefix { - ip, _ := netip.AddrFromSlice(std) - return netip.PrefixFrom(ip.Unmap(), int(bits)) -} - -func condNetAddrPrefix(ipp netip.Prefix) string { - if !ipp.Addr().IsValid() { - return "" - } - return ipp.String() -} - -func condNetAddrIP(ip netip.Addr) string { - if !ip.IsValid() { - return "" - } - return ip.String() -} - -// newRouteMessage is a message for a new route being added. -type newRouteMessage struct { - Src, Dst netip.Prefix - Gateway netip.Addr - Table uint8 -} - -const tsTable = 52 - -func (m *newRouteMessage) ignore() bool { - return m.Table == tsTable || tsaddr.IsTailscaleIP(m.Dst.Addr()) -} - -// newAddrMessage is a message for a new address being added. -type newAddrMessage struct { - Delete bool - Addr netip.Addr - IfIndex uint32 // interface index -} - -func (m *newAddrMessage) ignore() bool { - return tsaddr.IsTailscaleIP(m.Addr) -} - -type ignoreMessage struct{} - -func (ignoreMessage) ignore() bool { return true } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !android + +package netmon + +import ( + "net" + "net/netip" + "time" + + "github.com/jsimonetti/rtnetlink" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" + "tailscale.com/envknob" + "tailscale.com/net/tsaddr" + "tailscale.com/types/logger" +) + +var debugNetlinkMessages = envknob.RegisterBool("TS_DEBUG_NETLINK") + +// unspecifiedMessage is a minimal message implementation that should not +// be ignored. In general, OS-specific implementations should use better +// types and avoid this if they can. +type unspecifiedMessage struct{} + +func (unspecifiedMessage) ignore() bool { return false } + +// nlConn wraps a *netlink.Conn and returns a monitor.Message +// instead of a netlink.Message. Currently, messages are discarded, +// but down the line, when messages trigger different logic depending +// on the type of event, this provides the capability of handling +// each architecture-specific message in a generic fashion. +type nlConn struct { + logf logger.Logf + conn *netlink.Conn + buffered []netlink.Message + + // addrCache maps interface indices to a set of addresses, and is + // used to suppress duplicate RTM_NEWADDR messages. It is populated + // by RTM_NEWADDR messages and de-populated by RTM_DELADDR. See + // issue #4282. + addrCache map[uint32]map[netip.Addr]bool +} + +func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) { + conn, err := netlink.Dial(unix.NETLINK_ROUTE, &netlink.Config{ + // Routes get us most of the events of interest, but we need + // address as well to cover things like DHCP deciding to give + // us a new address upon renewal - routing wouldn't change, + // but all reachability would. + Groups: unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR | + unix.RTMGRP_IPV4_ROUTE | unix.RTMGRP_IPV6_ROUTE | + unix.RTMGRP_IPV4_RULE, // no IPV6_RULE in x/sys/unix + }) + if err != nil { + // Google Cloud Run does not implement NETLINK_ROUTE RTMGRP support + logf("monitor_linux: AF_NETLINK RTMGRP failed, falling back to polling") + return newPollingMon(logf, m) + } + return &nlConn{logf: logf, conn: conn, addrCache: make(map[uint32]map[netip.Addr]bool)}, nil +} + +func (c *nlConn) IsInterestingInterface(iface string) bool { return true } + +func (c *nlConn) Close() error { return c.conn.Close() } + +func (c *nlConn) Receive() (message, error) { + if len(c.buffered) == 0 { + var err error + c.buffered, err = c.conn.Receive() + if err != nil { + return nil, err + } + if len(c.buffered) == 0 { + // Unexpected. Not seen in wild, but sleep defensively. + time.Sleep(time.Second) + return ignoreMessage{}, nil + } + } + msg := c.buffered[0] + c.buffered = c.buffered[1:] + + // See https://github.com/torvalds/linux/blob/master/include/uapi/linux/rtnetlink.h + // And https://man7.org/linux/man-pages/man7/rtnetlink.7.html + switch msg.Header.Type { + case unix.RTM_NEWADDR, unix.RTM_DELADDR: + var rmsg rtnetlink.AddressMessage + if err := rmsg.UnmarshalBinary(msg.Data); err != nil { + c.logf("failed to parse type %v: %v", msg.Header.Type, err) + return unspecifiedMessage{}, nil + } + + nip := netaddrIP(rmsg.Attributes.Address) + + if debugNetlinkMessages() { + typ := "RTM_NEWADDR" + if msg.Header.Type == unix.RTM_DELADDR { + typ = "RTM_DELADDR" + } + + // label attributes are seemingly only populated for IPv4 addresses in the wild. + label := rmsg.Attributes.Label + if label == "" { + itf, err := net.InterfaceByIndex(int(rmsg.Index)) + if err == nil { + label = itf.Name + } + } + + c.logf("%s: %s(%d) %s / %s", typ, label, rmsg.Index, rmsg.Attributes.Address, rmsg.Attributes.Local) + } + + addrs := c.addrCache[rmsg.Index] + + // Ignore duplicate RTM_NEWADDR messages using c.addrCache to + // detect them. See nlConn.addrcache and issue #4282. + if msg.Header.Type == unix.RTM_NEWADDR { + if addrs == nil { + addrs = make(map[netip.Addr]bool) + c.addrCache[rmsg.Index] = addrs + } + + if addrs[nip] { + if debugNetlinkMessages() { + c.logf("ignored duplicate RTM_NEWADDR for %s", nip) + } + return ignoreMessage{}, nil + } + + addrs[nip] = true + } else { // msg.Header.Type == unix.RTM_DELADDR + if addrs != nil { + delete(addrs, nip) + } + + if len(addrs) == 0 { + delete(c.addrCache, rmsg.Index) + } + } + + nam := &newAddrMessage{ + IfIndex: rmsg.Index, + Addr: nip, + Delete: msg.Header.Type == unix.RTM_DELADDR, + } + if debugNetlinkMessages() { + c.logf("%+v", nam) + } + return nam, nil + case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: + typeStr := "RTM_NEWROUTE" + if msg.Header.Type == unix.RTM_DELROUTE { + typeStr = "RTM_DELROUTE" + } + var rmsg rtnetlink.RouteMessage + if err := rmsg.UnmarshalBinary(msg.Data); err != nil { + c.logf("%s: failed to parse: %v", typeStr, err) + return unspecifiedMessage{}, nil + } + src := netaddrIPPrefix(rmsg.Attributes.Src, rmsg.SrcLength) + dst := netaddrIPPrefix(rmsg.Attributes.Dst, rmsg.DstLength) + gw := netaddrIP(rmsg.Attributes.Gateway) + + if msg.Header.Type == unix.RTM_NEWROUTE && + (rmsg.Attributes.Table == 255 || rmsg.Attributes.Table == 254) && + (dst.Addr().IsMulticast() || dst.Addr().IsLinkLocalUnicast()) { + + if debugNetlinkMessages() { + c.logf("%s ignored", typeStr) + } + + // Normal Linux route changes on new interface coming up; don't log or react. + return ignoreMessage{}, nil + } + + if rmsg.Table == tsTable && dst.IsSingleIP() { + // Don't log. Spammy and normal to see a bunch of these on start-up, + // which we make ourselves. + } else if tsaddr.IsTailscaleIP(dst.Addr()) { + // Verbose only. + c.logf("%s: [v1] src=%v, dst=%v, gw=%v, outif=%v, table=%v", typeStr, + condNetAddrPrefix(src), condNetAddrPrefix(dst), condNetAddrIP(gw), + rmsg.Attributes.OutIface, rmsg.Attributes.Table) + } else { + c.logf("%s: src=%v, dst=%v, gw=%v, outif=%v, table=%v", typeStr, + condNetAddrPrefix(src), condNetAddrPrefix(dst), condNetAddrIP(gw), + rmsg.Attributes.OutIface, rmsg.Attributes.Table) + } + if msg.Header.Type == unix.RTM_DELROUTE { + // Just logging it for now. + // (Debugging https://github.com/tailscale/tailscale/issues/643) + return unspecifiedMessage{}, nil + } + + nrm := &newRouteMessage{ + Table: rmsg.Table, + Src: src, + Dst: dst, + Gateway: gw, + } + if debugNetlinkMessages() { + c.logf("%+v", nrm) + } + return nrm, nil + case unix.RTM_NEWRULE: + // Probably ourselves adding it. + return ignoreMessage{}, nil + case unix.RTM_DELRULE: + // For https://github.com/tailscale/tailscale/issues/1591 where + // systemd-networkd deletes our rules. + var rmsg rtnetlink.RouteMessage + err := rmsg.UnmarshalBinary(msg.Data) + if err != nil { + c.logf("ip rule deleted; failed to parse netlink message: %v", err) + } else { + c.logf("ip rule deleted: %+v", rmsg) + // On `ip -4 rule del pref 5210 table main`, logs: + // monitor: ip rule deleted: {Family:2 DstLength:0 SrcLength:0 Tos:0 Table:254 Protocol:0 Scope:0 Type:1 Flags:0 Attributes:{Dst: Src: Gateway: OutIface:0 Priority:5210 Table:254 Mark:4294967295 Expires: Metrics: Multipath:[]}} + } + rdm := ipRuleDeletedMessage{ + table: rmsg.Table, + priority: rmsg.Attributes.Priority, + } + if debugNetlinkMessages() { + c.logf("%+v", rdm) + } + return rdm, nil + case unix.RTM_NEWLINK, unix.RTM_DELLINK: + // This is an unhandled message, but don't print an error. + // See https://github.com/tailscale/tailscale/issues/6806 + return unspecifiedMessage{}, nil + default: + c.logf("unhandled netlink msg type %+v, %q", msg.Header, msg.Data) + return unspecifiedMessage{}, nil + } +} + +func netaddrIP(std net.IP) netip.Addr { + ip, _ := netip.AddrFromSlice(std) + return ip.Unmap() +} + +func netaddrIPPrefix(std net.IP, bits uint8) netip.Prefix { + ip, _ := netip.AddrFromSlice(std) + return netip.PrefixFrom(ip.Unmap(), int(bits)) +} + +func condNetAddrPrefix(ipp netip.Prefix) string { + if !ipp.Addr().IsValid() { + return "" + } + return ipp.String() +} + +func condNetAddrIP(ip netip.Addr) string { + if !ip.IsValid() { + return "" + } + return ip.String() +} + +// newRouteMessage is a message for a new route being added. +type newRouteMessage struct { + Src, Dst netip.Prefix + Gateway netip.Addr + Table uint8 +} + +const tsTable = 52 + +func (m *newRouteMessage) ignore() bool { + return m.Table == tsTable || tsaddr.IsTailscaleIP(m.Dst.Addr()) +} + +// newAddrMessage is a message for a new address being added. +type newAddrMessage struct { + Delete bool + Addr netip.Addr + IfIndex uint32 // interface index +} + +func (m *newAddrMessage) ignore() bool { + return tsaddr.IsTailscaleIP(m.Addr) +} + +type ignoreMessage struct{} + +func (ignoreMessage) ignore() bool { return true } diff --git a/net/netmon/netmon_polling.go b/net/netmon/netmon_polling.go index 1ce4a51de..3d6f94731 100644 --- a/net/netmon/netmon_polling.go +++ b/net/netmon/netmon_polling.go @@ -1,21 +1,21 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build (!linux && !freebsd && !windows && !darwin) || android - -package netmon - -import ( - "tailscale.com/types/logger" -) - -func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) { - return newPollingMon(logf, m) -} - -// unspecifiedMessage is a minimal message implementation that should not -// be ignored. In general, OS-specific implementations should use better -// types and avoid this if they can. -type unspecifiedMessage struct{} - -func (unspecifiedMessage) ignore() bool { return false } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build (!linux && !freebsd && !windows && !darwin) || android + +package netmon + +import ( + "tailscale.com/types/logger" +) + +func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) { + return newPollingMon(logf, m) +} + +// unspecifiedMessage is a minimal message implementation that should not +// be ignored. In general, OS-specific implementations should use better +// types and avoid this if they can. +type unspecifiedMessage struct{} + +func (unspecifiedMessage) ignore() bool { return false } diff --git a/net/netmon/polling.go b/net/netmon/polling.go index bb7210b94..ce1618ed6 100644 --- a/net/netmon/polling.go +++ b/net/netmon/polling.go @@ -1,86 +1,86 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows && !darwin - -package netmon - -import ( - "bytes" - "errors" - "os" - "runtime" - "sync" - "time" - - "tailscale.com/types/logger" -) - -func newPollingMon(logf logger.Logf, m *Monitor) (osMon, error) { - return &pollingMon{ - logf: logf, - m: m, - stop: make(chan struct{}), - }, nil -} - -// pollingMon is a bad but portable implementation of the link monitor -// that works by polling the interface state every 10 seconds, in lieu -// of anything to subscribe to. -type pollingMon struct { - logf logger.Logf - m *Monitor - - closeOnce sync.Once - stop chan struct{} -} - -func (pm *pollingMon) IsInterestingInterface(iface string) bool { - return true -} - -func (pm *pollingMon) Close() error { - pm.closeOnce.Do(func() { - close(pm.stop) - }) - return nil -} - -func (pm *pollingMon) isCloudRun() bool { - // https: //cloud.google.com/run/docs/reference/container-contract#env-vars - if os.Getenv("K_REVISION") == "" || os.Getenv("K_CONFIGURATION") == "" || - os.Getenv("K_SERVICE") == "" || os.Getenv("PORT") == "" { - return false - } - vers, err := os.ReadFile("/proc/version") - if err != nil { - pm.logf("Failed to read /proc/version: %v", err) - return false - } - return string(bytes.TrimSpace(vers)) == "Linux version 4.4.0 #1 SMP Sun Jan 10 15:06:54 PST 2016" -} - -func (pm *pollingMon) Receive() (message, error) { - d := 10 * time.Second - if runtime.GOOS == "android" { - // We'll have Android notify the link monitor to wake up earlier, - // so this can go very slowly there, to save battery. - // https://github.com/tailscale/tailscale/issues/1427 - d = 10 * time.Minute - } else if pm.isCloudRun() { - // Cloud Run routes never change at runtime. the containers are killed within - // 15 minutes by default, set the interval long enough to be effectively infinite. - pm.logf("monitor polling: Cloud Run detected, reduce polling interval to 24h") - d = 24 * time.Hour - } - timer := time.NewTimer(d) - defer timer.Stop() - for { - select { - case <-timer.C: - return unspecifiedMessage{}, nil - case <-pm.stop: - return nil, errors.New("stopped") - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows && !darwin + +package netmon + +import ( + "bytes" + "errors" + "os" + "runtime" + "sync" + "time" + + "tailscale.com/types/logger" +) + +func newPollingMon(logf logger.Logf, m *Monitor) (osMon, error) { + return &pollingMon{ + logf: logf, + m: m, + stop: make(chan struct{}), + }, nil +} + +// pollingMon is a bad but portable implementation of the link monitor +// that works by polling the interface state every 10 seconds, in lieu +// of anything to subscribe to. +type pollingMon struct { + logf logger.Logf + m *Monitor + + closeOnce sync.Once + stop chan struct{} +} + +func (pm *pollingMon) IsInterestingInterface(iface string) bool { + return true +} + +func (pm *pollingMon) Close() error { + pm.closeOnce.Do(func() { + close(pm.stop) + }) + return nil +} + +func (pm *pollingMon) isCloudRun() bool { + // https: //cloud.google.com/run/docs/reference/container-contract#env-vars + if os.Getenv("K_REVISION") == "" || os.Getenv("K_CONFIGURATION") == "" || + os.Getenv("K_SERVICE") == "" || os.Getenv("PORT") == "" { + return false + } + vers, err := os.ReadFile("/proc/version") + if err != nil { + pm.logf("Failed to read /proc/version: %v", err) + return false + } + return string(bytes.TrimSpace(vers)) == "Linux version 4.4.0 #1 SMP Sun Jan 10 15:06:54 PST 2016" +} + +func (pm *pollingMon) Receive() (message, error) { + d := 10 * time.Second + if runtime.GOOS == "android" { + // We'll have Android notify the link monitor to wake up earlier, + // so this can go very slowly there, to save battery. + // https://github.com/tailscale/tailscale/issues/1427 + d = 10 * time.Minute + } else if pm.isCloudRun() { + // Cloud Run routes never change at runtime. the containers are killed within + // 15 minutes by default, set the interval long enough to be effectively infinite. + pm.logf("monitor polling: Cloud Run detected, reduce polling interval to 24h") + d = 24 * time.Hour + } + timer := time.NewTimer(d) + defer timer.Stop() + for { + select { + case <-timer.C: + return unspecifiedMessage{}, nil + case <-pm.stop: + return nil, errors.New("stopped") + } + } +} diff --git a/net/netns/netns_android.go b/net/netns/netns_android.go index 79084ff11..162e5c79a 100644 --- a/net/netns/netns_android.go +++ b/net/netns/netns_android.go @@ -1,75 +1,75 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build android - -package netns - -import ( - "fmt" - "sync" - "syscall" - - "tailscale.com/net/netmon" - "tailscale.com/types/logger" -) - -var ( - androidProtectFuncMu sync.Mutex - androidProtectFunc func(fd int) error -) - -// UseSocketMark reports whether SO_MARK is in use. Android does not use SO_MARK. -func UseSocketMark() bool { - return false -} - -// SetAndroidProtectFunc register a func that Android provides that JNI calls into -// https://developer.android.com/reference/android/net/VpnService#protect(int) -// which is documented as: -// -// "Protect a socket from VPN connections. After protecting, data sent -// through this socket will go directly to the underlying network, so -// its traffic will not be forwarded through the VPN. This method is -// useful if some connections need to be kept outside of VPN. For -// example, a VPN tunnel should protect itself if its destination is -// covered by VPN routes. Otherwise its outgoing packets will be sent -// back to the VPN interface and cause an infinite loop. This method -// will fail if the application is not prepared or is revoked." -// -// A nil func disables the use the hook. -// -// This indirection is necessary because this is the supported, stable -// interface to use on Android, and doing the sockopts to set the -// fwmark return errors on Android. The actual implementation of -// VpnService.protect ends up doing an IPC to another process on -// Android, asking for the fwmark to be set. -func SetAndroidProtectFunc(f func(fd int) error) { - androidProtectFuncMu.Lock() - defer androidProtectFuncMu.Unlock() - androidProtectFunc = f -} - -func control(logger.Logf, *netmon.Monitor) func(network, address string, c syscall.RawConn) error { - return controlC -} - -// controlC marks c as necessary to dial in a separate network namespace. -// -// It's intentionally the same signature as net.Dialer.Control -// and net.ListenConfig.Control. -func controlC(network, address string, c syscall.RawConn) error { - var sockErr error - err := c.Control(func(fd uintptr) { - androidProtectFuncMu.Lock() - f := androidProtectFunc - androidProtectFuncMu.Unlock() - if f != nil { - sockErr = f(int(fd)) - } - }) - if err != nil { - return fmt.Errorf("RawConn.Control on %T: %w", c, err) - } - return sockErr -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build android + +package netns + +import ( + "fmt" + "sync" + "syscall" + + "tailscale.com/net/netmon" + "tailscale.com/types/logger" +) + +var ( + androidProtectFuncMu sync.Mutex + androidProtectFunc func(fd int) error +) + +// UseSocketMark reports whether SO_MARK is in use. Android does not use SO_MARK. +func UseSocketMark() bool { + return false +} + +// SetAndroidProtectFunc register a func that Android provides that JNI calls into +// https://developer.android.com/reference/android/net/VpnService#protect(int) +// which is documented as: +// +// "Protect a socket from VPN connections. After protecting, data sent +// through this socket will go directly to the underlying network, so +// its traffic will not be forwarded through the VPN. This method is +// useful if some connections need to be kept outside of VPN. For +// example, a VPN tunnel should protect itself if its destination is +// covered by VPN routes. Otherwise its outgoing packets will be sent +// back to the VPN interface and cause an infinite loop. This method +// will fail if the application is not prepared or is revoked." +// +// A nil func disables the use the hook. +// +// This indirection is necessary because this is the supported, stable +// interface to use on Android, and doing the sockopts to set the +// fwmark return errors on Android. The actual implementation of +// VpnService.protect ends up doing an IPC to another process on +// Android, asking for the fwmark to be set. +func SetAndroidProtectFunc(f func(fd int) error) { + androidProtectFuncMu.Lock() + defer androidProtectFuncMu.Unlock() + androidProtectFunc = f +} + +func control(logger.Logf, *netmon.Monitor) func(network, address string, c syscall.RawConn) error { + return controlC +} + +// controlC marks c as necessary to dial in a separate network namespace. +// +// It's intentionally the same signature as net.Dialer.Control +// and net.ListenConfig.Control. +func controlC(network, address string, c syscall.RawConn) error { + var sockErr error + err := c.Control(func(fd uintptr) { + androidProtectFuncMu.Lock() + f := androidProtectFunc + androidProtectFuncMu.Unlock() + if f != nil { + sockErr = f(int(fd)) + } + }) + if err != nil { + return fmt.Errorf("RawConn.Control on %T: %w", c, err) + } + return sockErr +} diff --git a/net/netns/netns_default.go b/net/netns/netns_default.go index 02db19e75..94f24d8fa 100644 --- a/net/netns/netns_default.go +++ b/net/netns/netns_default.go @@ -1,22 +1,22 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux && !windows && !darwin - -package netns - -import ( - "syscall" - - "tailscale.com/net/netmon" - "tailscale.com/types/logger" -) - -func control(logger.Logf, *netmon.Monitor) func(network, address string, c syscall.RawConn) error { - return controlC -} - -// controlC does nothing to c. -func controlC(network, address string, c syscall.RawConn) error { - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux && !windows && !darwin + +package netns + +import ( + "syscall" + + "tailscale.com/net/netmon" + "tailscale.com/types/logger" +) + +func control(logger.Logf, *netmon.Monitor) func(network, address string, c syscall.RawConn) error { + return controlC +} + +// controlC does nothing to c. +func controlC(network, address string, c syscall.RawConn) error { + return nil +} diff --git a/net/netns/netns_linux_test.go b/net/netns/netns_linux_test.go index cc221bcb1..a5000f37f 100644 --- a/net/netns/netns_linux_test.go +++ b/net/netns/netns_linux_test.go @@ -1,14 +1,14 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package netns - -import ( - "testing" -) - -func TestSocketMarkWorks(t *testing.T) { - _ = socketMarkWorks() - // we cannot actually assert whether the test runner has SO_MARK available - // or not, as we don't know. We're just checking that it doesn't panic. -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package netns + +import ( + "testing" +) + +func TestSocketMarkWorks(t *testing.T) { + _ = socketMarkWorks() + // we cannot actually assert whether the test runner has SO_MARK available + // or not, as we don't know. We're just checking that it doesn't panic. +} diff --git a/net/netns/netns_test.go b/net/netns/netns_test.go index 1c6d699ac..82f919b94 100644 --- a/net/netns/netns_test.go +++ b/net/netns/netns_test.go @@ -1,78 +1,78 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package netns contains the common code for using the Go net package -// in a logical "network namespace" to avoid routing loops where -// Tailscale-created packets would otherwise loop back through -// Tailscale routes. -// -// Despite the name netns, the exact mechanism used differs by -// operating system, and perhaps even by version of the OS. -// -// The netns package also handles connecting via SOCKS proxies when -// configured by the environment. -package netns - -import ( - "flag" - "testing" -) - -var extNetwork = flag.Bool("use-external-network", false, "use the external network in tests") - -func TestDial(t *testing.T) { - if !*extNetwork { - t.Skip("skipping test without --use-external-network") - } - d := NewDialer(t.Logf, nil) - c, err := d.Dial("tcp", "google.com:80") - if err != nil { - t.Fatal(err) - } - defer c.Close() - t.Logf("got addr %v", c.RemoteAddr()) - - c, err = d.Dial("tcp4", "google.com:80") - if err != nil { - t.Fatal(err) - } - defer c.Close() - t.Logf("got addr %v", c.RemoteAddr()) -} - -func TestIsLocalhost(t *testing.T) { - tests := []struct { - name string - host string - want bool - }{ - {"IPv4 loopback", "127.0.0.1", true}, - {"IPv4 !loopback", "192.168.0.1", false}, - {"IPv4 loopback with port", "127.0.0.1:1", true}, - {"IPv4 !loopback with port", "192.168.0.1:1", false}, - {"IPv4 unspecified", "0.0.0.0", false}, - {"IPv4 unspecified with port", "0.0.0.0:1", false}, - {"IPv6 loopback", "::1", true}, - {"IPv6 !loopback", "2001:4860:4860::8888", false}, - {"IPv6 loopback with port", "[::1]:1", true}, - {"IPv6 !loopback with port", "[2001:4860:4860::8888]:1", false}, - {"IPv6 unspecified", "::", false}, - {"IPv6 unspecified with port", "[::]:1", false}, - {"empty", "", false}, - {"hostname", "example.com", false}, - {"localhost", "localhost", true}, - {"localhost6", "localhost6", true}, - {"localhost with port", "localhost:1", true}, - {"localhost6 with port", "localhost6:1", true}, - {"ip6-localhost", "ip6-localhost", true}, - {"ip6-localhost with port", "ip6-localhost:1", true}, - {"ip6-loopback", "ip6-loopback", true}, - {"ip6-loopback with port", "ip6-loopback:1", true}, - } - - for _, test := range tests { - if got := isLocalhost(test.host); got != test.want { - t.Errorf("isLocalhost(%q) = %v, want %v", test.name, got, test.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package netns contains the common code for using the Go net package +// in a logical "network namespace" to avoid routing loops where +// Tailscale-created packets would otherwise loop back through +// Tailscale routes. +// +// Despite the name netns, the exact mechanism used differs by +// operating system, and perhaps even by version of the OS. +// +// The netns package also handles connecting via SOCKS proxies when +// configured by the environment. +package netns + +import ( + "flag" + "testing" +) + +var extNetwork = flag.Bool("use-external-network", false, "use the external network in tests") + +func TestDial(t *testing.T) { + if !*extNetwork { + t.Skip("skipping test without --use-external-network") + } + d := NewDialer(t.Logf, nil) + c, err := d.Dial("tcp", "google.com:80") + if err != nil { + t.Fatal(err) + } + defer c.Close() + t.Logf("got addr %v", c.RemoteAddr()) + + c, err = d.Dial("tcp4", "google.com:80") + if err != nil { + t.Fatal(err) + } + defer c.Close() + t.Logf("got addr %v", c.RemoteAddr()) +} + +func TestIsLocalhost(t *testing.T) { + tests := []struct { + name string + host string + want bool + }{ + {"IPv4 loopback", "127.0.0.1", true}, + {"IPv4 !loopback", "192.168.0.1", false}, + {"IPv4 loopback with port", "127.0.0.1:1", true}, + {"IPv4 !loopback with port", "192.168.0.1:1", false}, + {"IPv4 unspecified", "0.0.0.0", false}, + {"IPv4 unspecified with port", "0.0.0.0:1", false}, + {"IPv6 loopback", "::1", true}, + {"IPv6 !loopback", "2001:4860:4860::8888", false}, + {"IPv6 loopback with port", "[::1]:1", true}, + {"IPv6 !loopback with port", "[2001:4860:4860::8888]:1", false}, + {"IPv6 unspecified", "::", false}, + {"IPv6 unspecified with port", "[::]:1", false}, + {"empty", "", false}, + {"hostname", "example.com", false}, + {"localhost", "localhost", true}, + {"localhost6", "localhost6", true}, + {"localhost with port", "localhost:1", true}, + {"localhost6 with port", "localhost6:1", true}, + {"ip6-localhost", "ip6-localhost", true}, + {"ip6-localhost with port", "ip6-localhost:1", true}, + {"ip6-loopback", "ip6-loopback", true}, + {"ip6-loopback with port", "ip6-loopback:1", true}, + } + + for _, test := range tests { + if got := isLocalhost(test.host); got != test.want { + t.Errorf("isLocalhost(%q) = %v, want %v", test.name, got, test.want) + } + } +} diff --git a/net/netns/socks.go b/net/netns/socks.go index a3d10d3ae..eea69d865 100644 --- a/net/netns/socks.go +++ b/net/netns/socks.go @@ -1,19 +1,19 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !ios && !js - -package netns - -import "golang.org/x/net/proxy" - -func init() { - wrapDialer = wrapSocks -} - -func wrapSocks(d Dialer) Dialer { - if cd, ok := proxy.FromEnvironmentUsing(d).(Dialer); ok { - return cd - } - return d -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ios && !js + +package netns + +import "golang.org/x/net/proxy" + +func init() { + wrapDialer = wrapSocks +} + +func wrapSocks(d Dialer) Dialer { + if cd, ok := proxy.FromEnvironmentUsing(d).(Dialer); ok { + return cd + } + return d +} diff --git a/net/netstat/netstat.go b/net/netstat/netstat.go index 53121dc52..53c7d7757 100644 --- a/net/netstat/netstat.go +++ b/net/netstat/netstat.go @@ -1,35 +1,35 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package netstat returns the local machine's network connection table. -package netstat - -import ( - "errors" - "net/netip" - "runtime" -) - -var ErrNotImplemented = errors.New("not implemented for GOOS=" + runtime.GOOS) - -type Entry struct { - Local, Remote netip.AddrPort - Pid int - State string // TODO: type? - OSMetadata OSMetadata -} - -// Table contains local machine's TCP connection entries. -// -// Currently only TCP (IPv4 and IPv6) are included. -type Table struct { - Entries []Entry -} - -// Get returns the connection table. -// -// It returns ErrNotImplemented if the table is not available for the -// current operating system. -func Get() (*Table, error) { - return get() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package netstat returns the local machine's network connection table. +package netstat + +import ( + "errors" + "net/netip" + "runtime" +) + +var ErrNotImplemented = errors.New("not implemented for GOOS=" + runtime.GOOS) + +type Entry struct { + Local, Remote netip.AddrPort + Pid int + State string // TODO: type? + OSMetadata OSMetadata +} + +// Table contains local machine's TCP connection entries. +// +// Currently only TCP (IPv4 and IPv6) are included. +type Table struct { + Entries []Entry +} + +// Get returns the connection table. +// +// It returns ErrNotImplemented if the table is not available for the +// current operating system. +func Get() (*Table, error) { + return get() +} diff --git a/net/netstat/netstat_noimpl.go b/net/netstat/netstat_noimpl.go index 608b1a617..e455c8ce9 100644 --- a/net/netstat/netstat_noimpl.go +++ b/net/netstat/netstat_noimpl.go @@ -1,14 +1,14 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows - -package netstat - -// OSMetadata includes any additional OS-specific information that may be -// obtained during the retrieval of a given Entry. -type OSMetadata struct{} - -func get() (*Table, error) { - return nil, ErrNotImplemented -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows + +package netstat + +// OSMetadata includes any additional OS-specific information that may be +// obtained during the retrieval of a given Entry. +type OSMetadata struct{} + +func get() (*Table, error) { + return nil, ErrNotImplemented +} diff --git a/net/netstat/netstat_test.go b/net/netstat/netstat_test.go index 74f4fcec0..38827df5e 100644 --- a/net/netstat/netstat_test.go +++ b/net/netstat/netstat_test.go @@ -1,21 +1,21 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package netstat - -import ( - "testing" -) - -func TestGet(t *testing.T) { - nt, err := Get() - if err == ErrNotImplemented { - t.Skip("TODO: not implemented") - } - if err != nil { - t.Fatal(err) - } - for _, e := range nt.Entries { - t.Logf("Entry: %+v", e) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package netstat + +import ( + "testing" +) + +func TestGet(t *testing.T) { + nt, err := Get() + if err == ErrNotImplemented { + t.Skip("TODO: not implemented") + } + if err != nil { + t.Fatal(err) + } + for _, e := range nt.Entries { + t.Logf("Entry: %+v", e) + } +} diff --git a/net/packet/doc.go b/net/packet/doc.go index f3cb93db8..ce6c0c307 100644 --- a/net/packet/doc.go +++ b/net/packet/doc.go @@ -1,15 +1,15 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package packet contains packet parsing and marshaling utilities. -// -// Parsed provides allocation-free minimal packet header decoding, for -// use in packet filtering. The other types in the package are for -// constructing and marshaling packets into []bytes. -// -// To support allocation-free parsing, this package defines IPv4 and -// IPv6 address types. You should prefer to use netaddr's types, -// except where you absolutely need allocation-free IP handling -// (i.e. in the tunnel datapath) and are willing to implement all -// codepaths and data structures twice, once per IP family. -package packet +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package packet contains packet parsing and marshaling utilities. +// +// Parsed provides allocation-free minimal packet header decoding, for +// use in packet filtering. The other types in the package are for +// constructing and marshaling packets into []bytes. +// +// To support allocation-free parsing, this package defines IPv4 and +// IPv6 address types. You should prefer to use netaddr's types, +// except where you absolutely need allocation-free IP handling +// (i.e. in the tunnel datapath) and are willing to implement all +// codepaths and data structures twice, once per IP family. +package packet diff --git a/net/packet/header.go b/net/packet/header.go index 0b1947c0a..dbe84429a 100644 --- a/net/packet/header.go +++ b/net/packet/header.go @@ -1,66 +1,66 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - "errors" - "math" -) - -const tcpHeaderLength = 20 -const sctpHeaderLength = 12 - -// maxPacketLength is the largest length that all headers support. -// IPv4 headers using uint16 for this forces an upper bound of 64KB. -const maxPacketLength = math.MaxUint16 - -var ( - // errSmallBuffer is returned when Marshal receives a buffer - // too small to contain the header to marshal. - errSmallBuffer = errors.New("buffer too small") - // errLargePacket is returned when Marshal receives a payload - // larger than the maximum representable size in header - // fields. - errLargePacket = errors.New("packet too large") -) - -// Header is a packet header capable of marshaling itself into a byte -// buffer. -type Header interface { - // Len returns the length of the marshaled packet. - Len() int - // Marshal serializes the header into buf, which must be at - // least Len() bytes long. Implementations of Marshal assume - // that bytes after the first Len() are payload bytes for the - // purpose of computing length and checksum fields. Marshal - // implementations must not allocate memory. - Marshal(buf []byte) error -} - -// HeaderChecksummer is implemented by Header implementations that -// need to do a checksum over their payloads. -type HeaderChecksummer interface { - Header - - // WriteCheck writes the correct checksum into buf, which should - // be be the already-marshalled header and payload. - WriteChecksum(buf []byte) -} - -// Generate generates a new packet with the given Header and -// payload. This function allocates memory, see Header.Marshal for an -// allocation-free option. -func Generate(h Header, payload []byte) []byte { - hlen := h.Len() - buf := make([]byte, hlen+len(payload)) - - copy(buf[hlen:], payload) - h.Marshal(buf) - - if hc, ok := h.(HeaderChecksummer); ok { - hc.WriteChecksum(buf) - } - - return buf -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package packet + +import ( + "errors" + "math" +) + +const tcpHeaderLength = 20 +const sctpHeaderLength = 12 + +// maxPacketLength is the largest length that all headers support. +// IPv4 headers using uint16 for this forces an upper bound of 64KB. +const maxPacketLength = math.MaxUint16 + +var ( + // errSmallBuffer is returned when Marshal receives a buffer + // too small to contain the header to marshal. + errSmallBuffer = errors.New("buffer too small") + // errLargePacket is returned when Marshal receives a payload + // larger than the maximum representable size in header + // fields. + errLargePacket = errors.New("packet too large") +) + +// Header is a packet header capable of marshaling itself into a byte +// buffer. +type Header interface { + // Len returns the length of the marshaled packet. + Len() int + // Marshal serializes the header into buf, which must be at + // least Len() bytes long. Implementations of Marshal assume + // that bytes after the first Len() are payload bytes for the + // purpose of computing length and checksum fields. Marshal + // implementations must not allocate memory. + Marshal(buf []byte) error +} + +// HeaderChecksummer is implemented by Header implementations that +// need to do a checksum over their payloads. +type HeaderChecksummer interface { + Header + + // WriteCheck writes the correct checksum into buf, which should + // be be the already-marshalled header and payload. + WriteChecksum(buf []byte) +} + +// Generate generates a new packet with the given Header and +// payload. This function allocates memory, see Header.Marshal for an +// allocation-free option. +func Generate(h Header, payload []byte) []byte { + hlen := h.Len() + buf := make([]byte, hlen+len(payload)) + + copy(buf[hlen:], payload) + h.Marshal(buf) + + if hc, ok := h.(HeaderChecksummer); ok { + hc.WriteChecksum(buf) + } + + return buf +} diff --git a/net/packet/icmp.go b/net/packet/icmp.go index 7b86edd81..89a7aaa32 100644 --- a/net/packet/icmp.go +++ b/net/packet/icmp.go @@ -1,28 +1,28 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - crand "crypto/rand" - - "encoding/binary" -) - -// ICMPEchoPayload generates a new random ID/Sequence pair, and returns a uint32 -// derived from them, along with the id, sequence and given payload in a buffer. -// It returns an error if the random source could not be read. -func ICMPEchoPayload(payload []byte) (idSeq uint32, buf []byte) { - buf = make([]byte, len(payload)+4) - - // make a completely random id/sequence combo, which is very unlikely to - // collide with a running ping sequence on the host system. Errors are - // ignored, that would result in collisions, but errors reading from the - // random device are rare, and will cause this process universe to soon end. - crand.Read(buf[:4]) - - idSeq = binary.LittleEndian.Uint32(buf) - copy(buf[4:], payload) - - return -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package packet + +import ( + crand "crypto/rand" + + "encoding/binary" +) + +// ICMPEchoPayload generates a new random ID/Sequence pair, and returns a uint32 +// derived from them, along with the id, sequence and given payload in a buffer. +// It returns an error if the random source could not be read. +func ICMPEchoPayload(payload []byte) (idSeq uint32, buf []byte) { + buf = make([]byte, len(payload)+4) + + // make a completely random id/sequence combo, which is very unlikely to + // collide with a running ping sequence on the host system. Errors are + // ignored, that would result in collisions, but errors reading from the + // random device are rare, and will cause this process universe to soon end. + crand.Read(buf[:4]) + + idSeq = binary.LittleEndian.Uint32(buf) + copy(buf[4:], payload) + + return +} diff --git a/net/packet/icmp6_test.go b/net/packet/icmp6_test.go index c2fab353a..f34883ca4 100644 --- a/net/packet/icmp6_test.go +++ b/net/packet/icmp6_test.go @@ -1,79 +1,79 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - "net/netip" - "testing" - - "tailscale.com/types/ipproto" -) - -func TestICMPv6PingResponse(t *testing.T) { - pingHdr := ICMP6Header{ - IP6Header: IP6Header{ - Src: netip.MustParseAddr("1::1"), - Dst: netip.MustParseAddr("2::2"), - IPProto: ipproto.ICMPv6, - }, - Type: ICMP6EchoRequest, - Code: ICMP6NoCode, - } - - // echoReqLen is 2 bytes identifier + 2 bytes seq number. - // https://datatracker.ietf.org/doc/html/rfc4443#section-4.1 - // Packet.IsEchoRequest verifies that these 4 bytes are present. - const echoReqLen = 4 - buf := make([]byte, pingHdr.Len()+echoReqLen) - if err := pingHdr.Marshal(buf); err != nil { - t.Fatal(err) - } - - var p Parsed - p.Decode(buf) - if !p.IsEchoRequest() { - t.Fatalf("not an echo request, got: %+v", p) - } - - pingHdr.ToResponse() - buf = make([]byte, pingHdr.Len()+echoReqLen) - if err := pingHdr.Marshal(buf); err != nil { - t.Fatal(err) - } - - p.Decode(buf) - if p.IsEchoRequest() { - t.Fatalf("unexpectedly still an echo request: %+v", p) - } - if !p.IsEchoResponse() { - t.Fatalf("not an echo response: %+v", p) - } -} - -func TestICMPv6Checksum(t *testing.T) { - const req = "\x60\x0f\x07\x00\x00\x10\x3a\x40\xfd\x7a\x11\x5c\xa1\xe0\xab\x12" + - "\x48\x43\xcd\x96\x62\x7b\x65\x28\x26\x07\xf8\xb0\x40\x0a\x08\x07" + - "\x00\x00\x00\x00\x00\x00\x20\x0e\x80\x00\x4a\x9a\x2e\xea\x00\x02" + - "\x61\xb1\x9e\xad\x00\x06\x45\xaa" - // The packet that we'd originally generated incorrectly, but with the checksum - // bytes fixed per WireShark's correct calculation: - const wantRes = "\x60\x00\xf8\xff\x00\x10\x3a\x40\x26\x07\xf8\xb0\x40\x0a\x08\x07" + - "\x00\x00\x00\x00\x00\x00\x20\x0e\xfd\x7a\x11\x5c\xa1\xe0\xab\x12" + - "\x48\x43\xcd\x96\x62\x7b\x65\x28\x81\x00\x49\x9a\x2e\xea\x00\x02" + - "\x61\xb1\x9e\xad\x00\x06\x45\xaa" - - var p Parsed - p.Decode([]byte(req)) - if !p.IsEchoRequest() { - t.Fatalf("not an echo request, got: %+v", p) - } - - h := p.ICMP6Header() - h.ToResponse() - pong := Generate(&h, p.Payload()) - - if string(pong) != wantRes { - t.Errorf("wrong packet\n\n got: %x\nwant: %x", pong, wantRes) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package packet + +import ( + "net/netip" + "testing" + + "tailscale.com/types/ipproto" +) + +func TestICMPv6PingResponse(t *testing.T) { + pingHdr := ICMP6Header{ + IP6Header: IP6Header{ + Src: netip.MustParseAddr("1::1"), + Dst: netip.MustParseAddr("2::2"), + IPProto: ipproto.ICMPv6, + }, + Type: ICMP6EchoRequest, + Code: ICMP6NoCode, + } + + // echoReqLen is 2 bytes identifier + 2 bytes seq number. + // https://datatracker.ietf.org/doc/html/rfc4443#section-4.1 + // Packet.IsEchoRequest verifies that these 4 bytes are present. + const echoReqLen = 4 + buf := make([]byte, pingHdr.Len()+echoReqLen) + if err := pingHdr.Marshal(buf); err != nil { + t.Fatal(err) + } + + var p Parsed + p.Decode(buf) + if !p.IsEchoRequest() { + t.Fatalf("not an echo request, got: %+v", p) + } + + pingHdr.ToResponse() + buf = make([]byte, pingHdr.Len()+echoReqLen) + if err := pingHdr.Marshal(buf); err != nil { + t.Fatal(err) + } + + p.Decode(buf) + if p.IsEchoRequest() { + t.Fatalf("unexpectedly still an echo request: %+v", p) + } + if !p.IsEchoResponse() { + t.Fatalf("not an echo response: %+v", p) + } +} + +func TestICMPv6Checksum(t *testing.T) { + const req = "\x60\x0f\x07\x00\x00\x10\x3a\x40\xfd\x7a\x11\x5c\xa1\xe0\xab\x12" + + "\x48\x43\xcd\x96\x62\x7b\x65\x28\x26\x07\xf8\xb0\x40\x0a\x08\x07" + + "\x00\x00\x00\x00\x00\x00\x20\x0e\x80\x00\x4a\x9a\x2e\xea\x00\x02" + + "\x61\xb1\x9e\xad\x00\x06\x45\xaa" + // The packet that we'd originally generated incorrectly, but with the checksum + // bytes fixed per WireShark's correct calculation: + const wantRes = "\x60\x00\xf8\xff\x00\x10\x3a\x40\x26\x07\xf8\xb0\x40\x0a\x08\x07" + + "\x00\x00\x00\x00\x00\x00\x20\x0e\xfd\x7a\x11\x5c\xa1\xe0\xab\x12" + + "\x48\x43\xcd\x96\x62\x7b\x65\x28\x81\x00\x49\x9a\x2e\xea\x00\x02" + + "\x61\xb1\x9e\xad\x00\x06\x45\xaa" + + var p Parsed + p.Decode([]byte(req)) + if !p.IsEchoRequest() { + t.Fatalf("not an echo request, got: %+v", p) + } + + h := p.ICMP6Header() + h.ToResponse() + pong := Generate(&h, p.Payload()) + + if string(pong) != wantRes { + t.Errorf("wrong packet\n\n got: %x\nwant: %x", pong, wantRes) + } +} diff --git a/net/packet/ip4.go b/net/packet/ip4.go index 596bc766d..967a8dba7 100644 --- a/net/packet/ip4.go +++ b/net/packet/ip4.go @@ -1,116 +1,116 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - "encoding/binary" - "errors" - "net/netip" - - "tailscale.com/types/ipproto" -) - -// ip4HeaderLength is the length of an IPv4 header with no IP options. -const ip4HeaderLength = 20 - -// IP4Header represents an IPv4 packet header. -type IP4Header struct { - IPProto ipproto.Proto - IPID uint16 - Src netip.Addr - Dst netip.Addr -} - -// Len implements Header. -func (h IP4Header) Len() int { - return ip4HeaderLength -} - -var errWrongFamily = errors.New("wrong address family for src/dst IP") - -// Marshal implements Header. -func (h IP4Header) Marshal(buf []byte) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - if !h.Src.Is4() || !h.Dst.Is4() { - return errWrongFamily - } - - buf[0] = 0x40 | (byte(h.Len() >> 2)) // IPv4 + IHL - buf[1] = 0x00 // DSCP + ECN - binary.BigEndian.PutUint16(buf[2:4], uint16(len(buf))) // Total length - binary.BigEndian.PutUint16(buf[4:6], h.IPID) // ID - binary.BigEndian.PutUint16(buf[6:8], 0) // Flags + fragment offset - buf[8] = 64 // TTL - buf[9] = uint8(h.IPProto) // Inner protocol - // Blank checksum. This is necessary even though we overwrite - // it later, because the checksum computation runs over these - // bytes and expects them to be zero. - binary.BigEndian.PutUint16(buf[10:12], 0) - src := h.Src.As4() - dst := h.Dst.As4() - copy(buf[12:16], src[:]) - copy(buf[16:20], dst[:]) - - binary.BigEndian.PutUint16(buf[10:12], ip4Checksum(buf[0:20])) // Checksum - - return nil -} - -// ToResponse implements Header. -func (h *IP4Header) ToResponse() { - h.Src, h.Dst = h.Dst, h.Src - // Flip the bits in the IPID. If incoming IPIDs are distinct, so are these. - h.IPID = ^h.IPID -} - -// ip4Checksum computes an IPv4 checksum, as specified in -// https://tools.ietf.org/html/rfc1071 -func ip4Checksum(b []byte) uint16 { - var ac uint32 - i := 0 - n := len(b) - for n >= 2 { - ac += uint32(binary.BigEndian.Uint16(b[i : i+2])) - n -= 2 - i += 2 - } - if n == 1 { - ac += uint32(b[i]) << 8 - } - for (ac >> 16) > 0 { - ac = (ac >> 16) + (ac & 0xffff) - } - return uint16(^ac) -} - -// ip4PseudoHeaderOffset is the number of bytes by which the IPv4 UDP -// pseudo-header is smaller than the real IPv4 header. -const ip4PseudoHeaderOffset = 8 - -// marshalPseudo serializes h into buf in the "pseudo-header" form -// required when calculating UDP checksums. The pseudo-header starts -// at buf[ip4PseudoHeaderOffset] so as to abut the following UDP -// header, while leaving enough space in buf for a full IPv4 header. -func (h IP4Header) marshalPseudo(buf []byte) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - - length := len(buf) - h.Len() - src, dst := h.Src.As4(), h.Dst.As4() - copy(buf[8:12], src[:]) - copy(buf[12:16], dst[:]) - buf[16] = 0x0 - buf[17] = uint8(h.IPProto) - binary.BigEndian.PutUint16(buf[18:20], uint16(length)) - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package packet + +import ( + "encoding/binary" + "errors" + "net/netip" + + "tailscale.com/types/ipproto" +) + +// ip4HeaderLength is the length of an IPv4 header with no IP options. +const ip4HeaderLength = 20 + +// IP4Header represents an IPv4 packet header. +type IP4Header struct { + IPProto ipproto.Proto + IPID uint16 + Src netip.Addr + Dst netip.Addr +} + +// Len implements Header. +func (h IP4Header) Len() int { + return ip4HeaderLength +} + +var errWrongFamily = errors.New("wrong address family for src/dst IP") + +// Marshal implements Header. +func (h IP4Header) Marshal(buf []byte) error { + if len(buf) < h.Len() { + return errSmallBuffer + } + if len(buf) > maxPacketLength { + return errLargePacket + } + if !h.Src.Is4() || !h.Dst.Is4() { + return errWrongFamily + } + + buf[0] = 0x40 | (byte(h.Len() >> 2)) // IPv4 + IHL + buf[1] = 0x00 // DSCP + ECN + binary.BigEndian.PutUint16(buf[2:4], uint16(len(buf))) // Total length + binary.BigEndian.PutUint16(buf[4:6], h.IPID) // ID + binary.BigEndian.PutUint16(buf[6:8], 0) // Flags + fragment offset + buf[8] = 64 // TTL + buf[9] = uint8(h.IPProto) // Inner protocol + // Blank checksum. This is necessary even though we overwrite + // it later, because the checksum computation runs over these + // bytes and expects them to be zero. + binary.BigEndian.PutUint16(buf[10:12], 0) + src := h.Src.As4() + dst := h.Dst.As4() + copy(buf[12:16], src[:]) + copy(buf[16:20], dst[:]) + + binary.BigEndian.PutUint16(buf[10:12], ip4Checksum(buf[0:20])) // Checksum + + return nil +} + +// ToResponse implements Header. +func (h *IP4Header) ToResponse() { + h.Src, h.Dst = h.Dst, h.Src + // Flip the bits in the IPID. If incoming IPIDs are distinct, so are these. + h.IPID = ^h.IPID +} + +// ip4Checksum computes an IPv4 checksum, as specified in +// https://tools.ietf.org/html/rfc1071 +func ip4Checksum(b []byte) uint16 { + var ac uint32 + i := 0 + n := len(b) + for n >= 2 { + ac += uint32(binary.BigEndian.Uint16(b[i : i+2])) + n -= 2 + i += 2 + } + if n == 1 { + ac += uint32(b[i]) << 8 + } + for (ac >> 16) > 0 { + ac = (ac >> 16) + (ac & 0xffff) + } + return uint16(^ac) +} + +// ip4PseudoHeaderOffset is the number of bytes by which the IPv4 UDP +// pseudo-header is smaller than the real IPv4 header. +const ip4PseudoHeaderOffset = 8 + +// marshalPseudo serializes h into buf in the "pseudo-header" form +// required when calculating UDP checksums. The pseudo-header starts +// at buf[ip4PseudoHeaderOffset] so as to abut the following UDP +// header, while leaving enough space in buf for a full IPv4 header. +func (h IP4Header) marshalPseudo(buf []byte) error { + if len(buf) < h.Len() { + return errSmallBuffer + } + if len(buf) > maxPacketLength { + return errLargePacket + } + + length := len(buf) - h.Len() + src, dst := h.Src.As4(), h.Dst.As4() + copy(buf[8:12], src[:]) + copy(buf[12:16], dst[:]) + buf[16] = 0x0 + buf[17] = uint8(h.IPProto) + binary.BigEndian.PutUint16(buf[18:20], uint16(length)) + return nil +} diff --git a/net/packet/ip6.go b/net/packet/ip6.go index cebc46c53..d26b9a161 100644 --- a/net/packet/ip6.go +++ b/net/packet/ip6.go @@ -1,76 +1,76 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - "encoding/binary" - "net/netip" - - "tailscale.com/types/ipproto" -) - -// ip6HeaderLength is the length of an IPv6 header with no IP options. -const ip6HeaderLength = 40 - -// IP6Header represents an IPv6 packet header. -type IP6Header struct { - IPProto ipproto.Proto - IPID uint32 // only lower 20 bits used - Src netip.Addr - Dst netip.Addr -} - -// Len implements Header. -func (h IP6Header) Len() int { - return ip6HeaderLength -} - -// Marshal implements Header. -func (h IP6Header) Marshal(buf []byte) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - - binary.BigEndian.PutUint32(buf[:4], h.IPID&0x000FFFFF) - buf[0] = 0x60 - binary.BigEndian.PutUint16(buf[4:6], uint16(len(buf)-ip6HeaderLength)) // Total length - buf[6] = uint8(h.IPProto) // Inner protocol - buf[7] = 64 // TTL - src, dst := h.Src.As16(), h.Dst.As16() - copy(buf[8:24], src[:]) - copy(buf[24:40], dst[:]) - - return nil -} - -// ToResponse implements Header. -func (h *IP6Header) ToResponse() { - h.Src, h.Dst = h.Dst, h.Src - // Flip the bits in the IPID. If incoming IPIDs are distinct, so are these. - h.IPID = (^h.IPID) & 0x000FFFFF -} - -// marshalPseudo serializes h into buf in the "pseudo-header" form -// required when calculating UDP checksums. -func (h IP6Header) marshalPseudo(buf []byte, proto ipproto.Proto) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - - src, dst := h.Src.As16(), h.Dst.As16() - copy(buf[:16], src[:]) - copy(buf[16:32], dst[:]) - binary.BigEndian.PutUint32(buf[32:36], uint32(len(buf)-h.Len())) - buf[36] = 0 - buf[37] = 0 - buf[38] = 0 - buf[39] = byte(proto) // NextProto - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package packet + +import ( + "encoding/binary" + "net/netip" + + "tailscale.com/types/ipproto" +) + +// ip6HeaderLength is the length of an IPv6 header with no IP options. +const ip6HeaderLength = 40 + +// IP6Header represents an IPv6 packet header. +type IP6Header struct { + IPProto ipproto.Proto + IPID uint32 // only lower 20 bits used + Src netip.Addr + Dst netip.Addr +} + +// Len implements Header. +func (h IP6Header) Len() int { + return ip6HeaderLength +} + +// Marshal implements Header. +func (h IP6Header) Marshal(buf []byte) error { + if len(buf) < h.Len() { + return errSmallBuffer + } + if len(buf) > maxPacketLength { + return errLargePacket + } + + binary.BigEndian.PutUint32(buf[:4], h.IPID&0x000FFFFF) + buf[0] = 0x60 + binary.BigEndian.PutUint16(buf[4:6], uint16(len(buf)-ip6HeaderLength)) // Total length + buf[6] = uint8(h.IPProto) // Inner protocol + buf[7] = 64 // TTL + src, dst := h.Src.As16(), h.Dst.As16() + copy(buf[8:24], src[:]) + copy(buf[24:40], dst[:]) + + return nil +} + +// ToResponse implements Header. +func (h *IP6Header) ToResponse() { + h.Src, h.Dst = h.Dst, h.Src + // Flip the bits in the IPID. If incoming IPIDs are distinct, so are these. + h.IPID = (^h.IPID) & 0x000FFFFF +} + +// marshalPseudo serializes h into buf in the "pseudo-header" form +// required when calculating UDP checksums. +func (h IP6Header) marshalPseudo(buf []byte, proto ipproto.Proto) error { + if len(buf) < h.Len() { + return errSmallBuffer + } + if len(buf) > maxPacketLength { + return errLargePacket + } + + src, dst := h.Src.As16(), h.Dst.As16() + copy(buf[:16], src[:]) + copy(buf[16:32], dst[:]) + binary.BigEndian.PutUint32(buf[32:36], uint32(len(buf)-h.Len())) + buf[36] = 0 + buf[37] = 0 + buf[38] = 0 + buf[39] = byte(proto) // NextProto + return nil +} diff --git a/net/packet/tsmp_test.go b/net/packet/tsmp_test.go index 4ec24e1ea..e261e6a41 100644 --- a/net/packet/tsmp_test.go +++ b/net/packet/tsmp_test.go @@ -1,73 +1,73 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - "net/netip" - "testing" -) - -func TestTailscaleRejectedHeader(t *testing.T) { - tests := []struct { - h TailscaleRejectedHeader - wantStr string - }{ - { - h: TailscaleRejectedHeader{ - IPSrc: netip.MustParseAddr("5.5.5.5"), - IPDst: netip.MustParseAddr("1.2.3.4"), - Src: netip.MustParseAddrPort("1.2.3.4:567"), - Dst: netip.MustParseAddrPort("5.5.5.5:443"), - Proto: TCP, - Reason: RejectedDueToACLs, - }, - wantStr: "TSMP-reject-flow{TCP 1.2.3.4:567 > 5.5.5.5:443}: acl", - }, - { - h: TailscaleRejectedHeader{ - IPSrc: netip.MustParseAddr("2::2"), - IPDst: netip.MustParseAddr("1::1"), - Src: netip.MustParseAddrPort("[1::1]:567"), - Dst: netip.MustParseAddrPort("[2::2]:443"), - Proto: UDP, - Reason: RejectedDueToShieldsUp, - }, - wantStr: "TSMP-reject-flow{UDP [1::1]:567 > [2::2]:443}: shields", - }, - { - h: TailscaleRejectedHeader{ - IPSrc: netip.MustParseAddr("2::2"), - IPDst: netip.MustParseAddr("1::1"), - Src: netip.MustParseAddrPort("[1::1]:567"), - Dst: netip.MustParseAddrPort("[2::2]:443"), - Proto: UDP, - Reason: RejectedDueToIPForwarding, - MaybeBroken: true, - }, - wantStr: "TSMP-reject-flow{UDP [1::1]:567 > [2::2]:443}: host-ip-forwarding-unavailable", - }, - } - for i, tt := range tests { - gotStr := tt.h.String() - if gotStr != tt.wantStr { - t.Errorf("%v. String = %q; want %q", i, gotStr, tt.wantStr) - continue - } - pkt := make([]byte, tt.h.Len()) - tt.h.Marshal(pkt) - - var p Parsed - p.Decode(pkt) - t.Logf("Parsed: %+v", p) - t.Logf("Parsed: %s", p.String()) - back, ok := p.AsTailscaleRejectedHeader() - if !ok { - t.Errorf("%v. %q (%02x) didn't parse back", i, gotStr, pkt) - continue - } - if back != tt.h { - t.Errorf("%v. %q parsed back as %q", i, tt.h, back) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package packet + +import ( + "net/netip" + "testing" +) + +func TestTailscaleRejectedHeader(t *testing.T) { + tests := []struct { + h TailscaleRejectedHeader + wantStr string + }{ + { + h: TailscaleRejectedHeader{ + IPSrc: netip.MustParseAddr("5.5.5.5"), + IPDst: netip.MustParseAddr("1.2.3.4"), + Src: netip.MustParseAddrPort("1.2.3.4:567"), + Dst: netip.MustParseAddrPort("5.5.5.5:443"), + Proto: TCP, + Reason: RejectedDueToACLs, + }, + wantStr: "TSMP-reject-flow{TCP 1.2.3.4:567 > 5.5.5.5:443}: acl", + }, + { + h: TailscaleRejectedHeader{ + IPSrc: netip.MustParseAddr("2::2"), + IPDst: netip.MustParseAddr("1::1"), + Src: netip.MustParseAddrPort("[1::1]:567"), + Dst: netip.MustParseAddrPort("[2::2]:443"), + Proto: UDP, + Reason: RejectedDueToShieldsUp, + }, + wantStr: "TSMP-reject-flow{UDP [1::1]:567 > [2::2]:443}: shields", + }, + { + h: TailscaleRejectedHeader{ + IPSrc: netip.MustParseAddr("2::2"), + IPDst: netip.MustParseAddr("1::1"), + Src: netip.MustParseAddrPort("[1::1]:567"), + Dst: netip.MustParseAddrPort("[2::2]:443"), + Proto: UDP, + Reason: RejectedDueToIPForwarding, + MaybeBroken: true, + }, + wantStr: "TSMP-reject-flow{UDP [1::1]:567 > [2::2]:443}: host-ip-forwarding-unavailable", + }, + } + for i, tt := range tests { + gotStr := tt.h.String() + if gotStr != tt.wantStr { + t.Errorf("%v. String = %q; want %q", i, gotStr, tt.wantStr) + continue + } + pkt := make([]byte, tt.h.Len()) + tt.h.Marshal(pkt) + + var p Parsed + p.Decode(pkt) + t.Logf("Parsed: %+v", p) + t.Logf("Parsed: %s", p.String()) + back, ok := p.AsTailscaleRejectedHeader() + if !ok { + t.Errorf("%v. %q (%02x) didn't parse back", i, gotStr, pkt) + continue + } + if back != tt.h { + t.Errorf("%v. %q parsed back as %q", i, tt.h, back) + } + } +} diff --git a/net/packet/udp4.go b/net/packet/udp4.go index c8761baef..0d5bca73e 100644 --- a/net/packet/udp4.go +++ b/net/packet/udp4.go @@ -1,58 +1,58 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - "encoding/binary" - - "tailscale.com/types/ipproto" -) - -// udpHeaderLength is the size of the UDP packet header, not including -// the outer IP header. -const udpHeaderLength = 8 - -// UDP4Header is an IPv4+UDP header. -type UDP4Header struct { - IP4Header - SrcPort uint16 - DstPort uint16 -} - -// Len implements Header. -func (h UDP4Header) Len() int { - return h.IP4Header.Len() + udpHeaderLength -} - -// Marshal implements Header. -func (h UDP4Header) Marshal(buf []byte) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - // The caller does not need to set this. - h.IPProto = ipproto.UDP - - length := len(buf) - h.IP4Header.Len() - binary.BigEndian.PutUint16(buf[20:22], h.SrcPort) - binary.BigEndian.PutUint16(buf[22:24], h.DstPort) - binary.BigEndian.PutUint16(buf[24:26], uint16(length)) - binary.BigEndian.PutUint16(buf[26:28], 0) // blank checksum - - // UDP checksum with IP pseudo header. - h.IP4Header.marshalPseudo(buf) - binary.BigEndian.PutUint16(buf[26:28], ip4Checksum(buf[ip4PseudoHeaderOffset:])) - - h.IP4Header.Marshal(buf) - - return nil -} - -// ToResponse implements Header. -func (h *UDP4Header) ToResponse() { - h.SrcPort, h.DstPort = h.DstPort, h.SrcPort - h.IP4Header.ToResponse() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package packet + +import ( + "encoding/binary" + + "tailscale.com/types/ipproto" +) + +// udpHeaderLength is the size of the UDP packet header, not including +// the outer IP header. +const udpHeaderLength = 8 + +// UDP4Header is an IPv4+UDP header. +type UDP4Header struct { + IP4Header + SrcPort uint16 + DstPort uint16 +} + +// Len implements Header. +func (h UDP4Header) Len() int { + return h.IP4Header.Len() + udpHeaderLength +} + +// Marshal implements Header. +func (h UDP4Header) Marshal(buf []byte) error { + if len(buf) < h.Len() { + return errSmallBuffer + } + if len(buf) > maxPacketLength { + return errLargePacket + } + // The caller does not need to set this. + h.IPProto = ipproto.UDP + + length := len(buf) - h.IP4Header.Len() + binary.BigEndian.PutUint16(buf[20:22], h.SrcPort) + binary.BigEndian.PutUint16(buf[22:24], h.DstPort) + binary.BigEndian.PutUint16(buf[24:26], uint16(length)) + binary.BigEndian.PutUint16(buf[26:28], 0) // blank checksum + + // UDP checksum with IP pseudo header. + h.IP4Header.marshalPseudo(buf) + binary.BigEndian.PutUint16(buf[26:28], ip4Checksum(buf[ip4PseudoHeaderOffset:])) + + h.IP4Header.Marshal(buf) + + return nil +} + +// ToResponse implements Header. +func (h *UDP4Header) ToResponse() { + h.SrcPort, h.DstPort = h.DstPort, h.SrcPort + h.IP4Header.ToResponse() +} diff --git a/net/packet/udp6.go b/net/packet/udp6.go index c8634b508..10fdcb99e 100644 --- a/net/packet/udp6.go +++ b/net/packet/udp6.go @@ -1,54 +1,54 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package packet - -import ( - "encoding/binary" - - "tailscale.com/types/ipproto" -) - -// UDP6Header is an IPv6+UDP header. -type UDP6Header struct { - IP6Header - SrcPort uint16 - DstPort uint16 -} - -// Len implements Header. -func (h UDP6Header) Len() int { - return h.IP6Header.Len() + udpHeaderLength -} - -// Marshal implements Header. -func (h UDP6Header) Marshal(buf []byte) error { - if len(buf) < h.Len() { - return errSmallBuffer - } - if len(buf) > maxPacketLength { - return errLargePacket - } - // The caller does not need to set this. - h.IPProto = ipproto.UDP - - length := len(buf) - h.IP6Header.Len() - binary.BigEndian.PutUint16(buf[40:42], h.SrcPort) - binary.BigEndian.PutUint16(buf[42:44], h.DstPort) - binary.BigEndian.PutUint16(buf[44:46], uint16(length)) - binary.BigEndian.PutUint16(buf[46:48], 0) // blank checksum - - // UDP checksum with IP pseudo header. - h.IP6Header.marshalPseudo(buf, ipproto.UDP) - binary.BigEndian.PutUint16(buf[46:48], ip4Checksum(buf[:])) - - h.IP6Header.Marshal(buf) - - return nil -} - -// ToResponse implements Header. -func (h *UDP6Header) ToResponse() { - h.SrcPort, h.DstPort = h.DstPort, h.SrcPort - h.IP6Header.ToResponse() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package packet + +import ( + "encoding/binary" + + "tailscale.com/types/ipproto" +) + +// UDP6Header is an IPv6+UDP header. +type UDP6Header struct { + IP6Header + SrcPort uint16 + DstPort uint16 +} + +// Len implements Header. +func (h UDP6Header) Len() int { + return h.IP6Header.Len() + udpHeaderLength +} + +// Marshal implements Header. +func (h UDP6Header) Marshal(buf []byte) error { + if len(buf) < h.Len() { + return errSmallBuffer + } + if len(buf) > maxPacketLength { + return errLargePacket + } + // The caller does not need to set this. + h.IPProto = ipproto.UDP + + length := len(buf) - h.IP6Header.Len() + binary.BigEndian.PutUint16(buf[40:42], h.SrcPort) + binary.BigEndian.PutUint16(buf[42:44], h.DstPort) + binary.BigEndian.PutUint16(buf[44:46], uint16(length)) + binary.BigEndian.PutUint16(buf[46:48], 0) // blank checksum + + // UDP checksum with IP pseudo header. + h.IP6Header.marshalPseudo(buf, ipproto.UDP) + binary.BigEndian.PutUint16(buf[46:48], ip4Checksum(buf[:])) + + h.IP6Header.Marshal(buf) + + return nil +} + +// ToResponse implements Header. +func (h *UDP6Header) ToResponse() { + h.SrcPort, h.DstPort = h.DstPort, h.SrcPort + h.IP6Header.ToResponse() +} diff --git a/net/ping/ping.go b/net/ping/ping.go index f2093292a..01f3dcf2c 100644 --- a/net/ping/ping.go +++ b/net/ping/ping.go @@ -1,343 +1,343 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package ping allows sending ICMP echo requests to a host in order to -// determine network latency. -package ping - -import ( - "bytes" - "context" - "crypto/rand" - "encoding/binary" - "fmt" - "io" - "log" - "net" - "net/netip" - "sync" - "sync/atomic" - "time" - - "golang.org/x/net/icmp" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" - "tailscale.com/types/logger" - "tailscale.com/util/mak" - "tailscale.com/util/multierr" -) - -const ( - v4Type = "ip4:icmp" - v6Type = "ip6:icmp" -) - -type response struct { - t time.Time - err error -} - -type outstanding struct { - ch chan response - data []byte -} - -// PacketListener defines the interface required to listen to packages -// on an address. -type ListenPacketer interface { - ListenPacket(ctx context.Context, typ string, addr string) (net.PacketConn, error) -} - -// Pinger represents a set of ICMP echo requests to be sent at a single time. -// -// A new instance should be created for each concurrent set of ping requests; -// this type should not be reused. -type Pinger struct { - lp ListenPacketer - - // closed guards against send incrementing the waitgroup concurrently with close. - closed atomic.Bool - Logf logger.Logf - Verbose bool - timeNow func() time.Time - id uint16 // uint16 per RFC 792 - wg sync.WaitGroup - - // Following fields protected by mu - mu sync.Mutex - // conns is a map of "type" to net.PacketConn, type is either - // "ip4:icmp" or "ip6:icmp" - conns map[string]net.PacketConn - seq uint16 // uint16 per RFC 792 - pings map[uint16]outstanding -} - -// New creates a new Pinger. The Context provided will be used to create -// network listeners, and to set an absolute deadline (if any) on the net.Conn -func New(ctx context.Context, logf logger.Logf, lp ListenPacketer) *Pinger { - var id [2]byte - if _, err := io.ReadFull(rand.Reader, id[:]); err != nil { - panic("net/ping: New:" + err.Error()) - } - - return &Pinger{ - lp: lp, - Logf: logf, - timeNow: time.Now, - id: binary.LittleEndian.Uint16(id[:]), - pings: make(map[uint16]outstanding), - } -} - -func (p *Pinger) mkconn(ctx context.Context, typ, addr string) (net.PacketConn, error) { - if p.closed.Load() { - return nil, net.ErrClosed - } - - c, err := p.lp.ListenPacket(ctx, typ, addr) - if err != nil { - return nil, err - } - - // Start by setting the deadline from the context; note that this - // applies to all future I/O, so we only need to do it once. - deadline, ok := ctx.Deadline() - if ok { - if err := c.SetReadDeadline(deadline); err != nil { - return nil, err - } - } - - p.wg.Add(1) - go p.run(ctx, c, typ) - - return c, err -} - -// getConn creates or returns a conn matching typ which is ip4:icmp -// or ip6:icmp. -func (p *Pinger) getConn(ctx context.Context, typ string) (net.PacketConn, error) { - p.mu.Lock() - defer p.mu.Unlock() - if c, ok := p.conns[typ]; ok { - return c, nil - } - - var addr = "0.0.0.0" - if typ == v6Type { - addr = "::" - } - c, err := p.mkconn(ctx, typ, addr) - if err != nil { - return nil, err - } - mak.Set(&p.conns, typ, c) - return c, nil -} - -func (p *Pinger) logf(format string, a ...any) { - if p.Logf != nil { - p.Logf(format, a...) - } else { - log.Printf(format, a...) - } -} - -func (p *Pinger) vlogf(format string, a ...any) { - if p.Verbose { - p.logf(format, a...) - } -} - -func (p *Pinger) Close() error { - p.closed.Store(true) - - p.mu.Lock() - conns := p.conns - p.conns = nil - p.mu.Unlock() - - var errors []error - for _, c := range conns { - if err := c.Close(); err != nil { - errors = append(errors, err) - } - } - - p.wg.Wait() - p.cleanupOutstanding() - - return multierr.New(errors...) -} - -func (p *Pinger) run(ctx context.Context, conn net.PacketConn, typ string) { - defer p.wg.Done() - defer func() { - conn.Close() - p.mu.Lock() - delete(p.conns, typ) - p.mu.Unlock() - }() - buf := make([]byte, 1500) - -loop: - for { - select { - case <-ctx.Done(): - break loop - default: - } - - n, _, err := conn.ReadFrom(buf) - if err != nil { - // Ignore temporary errors; everything else is fatal - if netErr, ok := err.(net.Error); !ok || !netErr.Temporary() { - break - } - continue - } - - p.handleResponse(buf[:n], p.timeNow(), typ) - } -} - -func (p *Pinger) cleanupOutstanding() { - // Complete outstanding requests - p.mu.Lock() - defer p.mu.Unlock() - for _, o := range p.pings { - o.ch <- response{err: net.ErrClosed} - } -} - -func (p *Pinger) handleResponse(buf []byte, now time.Time, typ string) { - // We need to handle responding to both IPv4 - // and IPv6. - var icmpType icmp.Type - switch typ { - case v4Type: - icmpType = ipv4.ICMPTypeEchoReply - case v6Type: - icmpType = ipv6.ICMPTypeEchoReply - default: - p.vlogf("handleResponse: unknown icmp.Type") - return - } - - m, err := icmp.ParseMessage(icmpType.Protocol(), buf) - if err != nil { - p.vlogf("handleResponse: invalid packet: %v", err) - return - } - - if m.Type != icmpType { - p.vlogf("handleResponse: wanted m.Type=%d; got %d", icmpType, m.Type) - return - } - - resp, ok := m.Body.(*icmp.Echo) - if !ok || resp == nil { - p.vlogf("handleResponse: wanted body=*icmp.Echo; got %v", m.Body) - return - } - - // We assume we sent this if the ID in the response is ours. - if uint16(resp.ID) != p.id { - p.vlogf("handleResponse: wanted ID=%d; got %d", p.id, resp.ID) - return - } - - // Search for existing running echo request - var o outstanding - p.mu.Lock() - if o, ok = p.pings[uint16(resp.Seq)]; ok { - // Ensure that the data matches before we delete from our map, - // so a future correct packet will be handled correctly. - if bytes.Equal(resp.Data, o.data) { - delete(p.pings, uint16(resp.Seq)) - } else { - p.vlogf("handleResponse: got response for Seq %d with mismatched data", resp.Seq) - ok = false - } - } else { - p.vlogf("handleResponse: got response for unknown Seq %d", resp.Seq) - } - p.mu.Unlock() - - if ok { - o.ch <- response{t: now} - } -} - -// Send sends an ICMP Echo Request packet to the destination, waits for a -// response, and returns the duration between when the request was sent and -// when the reply was received. -// -// If provided, "data" is sent with the packet and is compared upon receiving a -// reply. -func (p *Pinger) Send(ctx context.Context, dest net.Addr, data []byte) (time.Duration, error) { - // Use sequential sequence numbers on the assumption that we will not - // wrap around when using a single Pinger instance - p.mu.Lock() - p.seq++ - seq := p.seq - p.mu.Unlock() - - // Check whether the address is IPv4 or IPv6 to - // determine the icmp.Type and conn to use. - var conn net.PacketConn - var icmpType icmp.Type = ipv4.ICMPTypeEcho - ap, err := netip.ParseAddr(dest.String()) - if err != nil { - return 0, err - } - if ap.Is6() { - icmpType = ipv6.ICMPTypeEchoRequest - conn, err = p.getConn(ctx, v6Type) - } else { - conn, err = p.getConn(ctx, v4Type) - } - if err != nil { - return 0, err - } - - m := icmp.Message{ - Type: icmpType, - Code: 0, - Body: &icmp.Echo{ - ID: int(p.id), - Seq: int(seq), - Data: data, - }, - } - b, err := m.Marshal(nil) - if err != nil { - return 0, err - } - - // Register our response before sending since we could otherwise race a - // quick reply. - ch := make(chan response, 1) - p.mu.Lock() - p.pings[seq] = outstanding{ch: ch, data: data} - p.mu.Unlock() - - start := p.timeNow() - n, err := conn.WriteTo(b, dest) - if err != nil { - return 0, err - } else if n != len(b) { - return 0, fmt.Errorf("conn.WriteTo: got %v; want %v", n, len(b)) - } - - select { - case resp := <-ch: - if resp.err != nil { - return 0, resp.err - } - return resp.t.Sub(start), nil - - case <-ctx.Done(): - return 0, ctx.Err() - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package ping allows sending ICMP echo requests to a host in order to +// determine network latency. +package ping + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/binary" + "fmt" + "io" + "log" + "net" + "net/netip" + "sync" + "sync/atomic" + "time" + + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "tailscale.com/types/logger" + "tailscale.com/util/mak" + "tailscale.com/util/multierr" +) + +const ( + v4Type = "ip4:icmp" + v6Type = "ip6:icmp" +) + +type response struct { + t time.Time + err error +} + +type outstanding struct { + ch chan response + data []byte +} + +// PacketListener defines the interface required to listen to packages +// on an address. +type ListenPacketer interface { + ListenPacket(ctx context.Context, typ string, addr string) (net.PacketConn, error) +} + +// Pinger represents a set of ICMP echo requests to be sent at a single time. +// +// A new instance should be created for each concurrent set of ping requests; +// this type should not be reused. +type Pinger struct { + lp ListenPacketer + + // closed guards against send incrementing the waitgroup concurrently with close. + closed atomic.Bool + Logf logger.Logf + Verbose bool + timeNow func() time.Time + id uint16 // uint16 per RFC 792 + wg sync.WaitGroup + + // Following fields protected by mu + mu sync.Mutex + // conns is a map of "type" to net.PacketConn, type is either + // "ip4:icmp" or "ip6:icmp" + conns map[string]net.PacketConn + seq uint16 // uint16 per RFC 792 + pings map[uint16]outstanding +} + +// New creates a new Pinger. The Context provided will be used to create +// network listeners, and to set an absolute deadline (if any) on the net.Conn +func New(ctx context.Context, logf logger.Logf, lp ListenPacketer) *Pinger { + var id [2]byte + if _, err := io.ReadFull(rand.Reader, id[:]); err != nil { + panic("net/ping: New:" + err.Error()) + } + + return &Pinger{ + lp: lp, + Logf: logf, + timeNow: time.Now, + id: binary.LittleEndian.Uint16(id[:]), + pings: make(map[uint16]outstanding), + } +} + +func (p *Pinger) mkconn(ctx context.Context, typ, addr string) (net.PacketConn, error) { + if p.closed.Load() { + return nil, net.ErrClosed + } + + c, err := p.lp.ListenPacket(ctx, typ, addr) + if err != nil { + return nil, err + } + + // Start by setting the deadline from the context; note that this + // applies to all future I/O, so we only need to do it once. + deadline, ok := ctx.Deadline() + if ok { + if err := c.SetReadDeadline(deadline); err != nil { + return nil, err + } + } + + p.wg.Add(1) + go p.run(ctx, c, typ) + + return c, err +} + +// getConn creates or returns a conn matching typ which is ip4:icmp +// or ip6:icmp. +func (p *Pinger) getConn(ctx context.Context, typ string) (net.PacketConn, error) { + p.mu.Lock() + defer p.mu.Unlock() + if c, ok := p.conns[typ]; ok { + return c, nil + } + + var addr = "0.0.0.0" + if typ == v6Type { + addr = "::" + } + c, err := p.mkconn(ctx, typ, addr) + if err != nil { + return nil, err + } + mak.Set(&p.conns, typ, c) + return c, nil +} + +func (p *Pinger) logf(format string, a ...any) { + if p.Logf != nil { + p.Logf(format, a...) + } else { + log.Printf(format, a...) + } +} + +func (p *Pinger) vlogf(format string, a ...any) { + if p.Verbose { + p.logf(format, a...) + } +} + +func (p *Pinger) Close() error { + p.closed.Store(true) + + p.mu.Lock() + conns := p.conns + p.conns = nil + p.mu.Unlock() + + var errors []error + for _, c := range conns { + if err := c.Close(); err != nil { + errors = append(errors, err) + } + } + + p.wg.Wait() + p.cleanupOutstanding() + + return multierr.New(errors...) +} + +func (p *Pinger) run(ctx context.Context, conn net.PacketConn, typ string) { + defer p.wg.Done() + defer func() { + conn.Close() + p.mu.Lock() + delete(p.conns, typ) + p.mu.Unlock() + }() + buf := make([]byte, 1500) + +loop: + for { + select { + case <-ctx.Done(): + break loop + default: + } + + n, _, err := conn.ReadFrom(buf) + if err != nil { + // Ignore temporary errors; everything else is fatal + if netErr, ok := err.(net.Error); !ok || !netErr.Temporary() { + break + } + continue + } + + p.handleResponse(buf[:n], p.timeNow(), typ) + } +} + +func (p *Pinger) cleanupOutstanding() { + // Complete outstanding requests + p.mu.Lock() + defer p.mu.Unlock() + for _, o := range p.pings { + o.ch <- response{err: net.ErrClosed} + } +} + +func (p *Pinger) handleResponse(buf []byte, now time.Time, typ string) { + // We need to handle responding to both IPv4 + // and IPv6. + var icmpType icmp.Type + switch typ { + case v4Type: + icmpType = ipv4.ICMPTypeEchoReply + case v6Type: + icmpType = ipv6.ICMPTypeEchoReply + default: + p.vlogf("handleResponse: unknown icmp.Type") + return + } + + m, err := icmp.ParseMessage(icmpType.Protocol(), buf) + if err != nil { + p.vlogf("handleResponse: invalid packet: %v", err) + return + } + + if m.Type != icmpType { + p.vlogf("handleResponse: wanted m.Type=%d; got %d", icmpType, m.Type) + return + } + + resp, ok := m.Body.(*icmp.Echo) + if !ok || resp == nil { + p.vlogf("handleResponse: wanted body=*icmp.Echo; got %v", m.Body) + return + } + + // We assume we sent this if the ID in the response is ours. + if uint16(resp.ID) != p.id { + p.vlogf("handleResponse: wanted ID=%d; got %d", p.id, resp.ID) + return + } + + // Search for existing running echo request + var o outstanding + p.mu.Lock() + if o, ok = p.pings[uint16(resp.Seq)]; ok { + // Ensure that the data matches before we delete from our map, + // so a future correct packet will be handled correctly. + if bytes.Equal(resp.Data, o.data) { + delete(p.pings, uint16(resp.Seq)) + } else { + p.vlogf("handleResponse: got response for Seq %d with mismatched data", resp.Seq) + ok = false + } + } else { + p.vlogf("handleResponse: got response for unknown Seq %d", resp.Seq) + } + p.mu.Unlock() + + if ok { + o.ch <- response{t: now} + } +} + +// Send sends an ICMP Echo Request packet to the destination, waits for a +// response, and returns the duration between when the request was sent and +// when the reply was received. +// +// If provided, "data" is sent with the packet and is compared upon receiving a +// reply. +func (p *Pinger) Send(ctx context.Context, dest net.Addr, data []byte) (time.Duration, error) { + // Use sequential sequence numbers on the assumption that we will not + // wrap around when using a single Pinger instance + p.mu.Lock() + p.seq++ + seq := p.seq + p.mu.Unlock() + + // Check whether the address is IPv4 or IPv6 to + // determine the icmp.Type and conn to use. + var conn net.PacketConn + var icmpType icmp.Type = ipv4.ICMPTypeEcho + ap, err := netip.ParseAddr(dest.String()) + if err != nil { + return 0, err + } + if ap.Is6() { + icmpType = ipv6.ICMPTypeEchoRequest + conn, err = p.getConn(ctx, v6Type) + } else { + conn, err = p.getConn(ctx, v4Type) + } + if err != nil { + return 0, err + } + + m := icmp.Message{ + Type: icmpType, + Code: 0, + Body: &icmp.Echo{ + ID: int(p.id), + Seq: int(seq), + Data: data, + }, + } + b, err := m.Marshal(nil) + if err != nil { + return 0, err + } + + // Register our response before sending since we could otherwise race a + // quick reply. + ch := make(chan response, 1) + p.mu.Lock() + p.pings[seq] = outstanding{ch: ch, data: data} + p.mu.Unlock() + + start := p.timeNow() + n, err := conn.WriteTo(b, dest) + if err != nil { + return 0, err + } else if n != len(b) { + return 0, fmt.Errorf("conn.WriteTo: got %v; want %v", n, len(b)) + } + + select { + case resp := <-ch: + if resp.err != nil { + return 0, resp.err + } + return resp.t.Sub(start), nil + + case <-ctx.Done(): + return 0, ctx.Err() + } +} diff --git a/net/ping/ping_test.go b/net/ping/ping_test.go index 5232f6ada..bbedbcad8 100644 --- a/net/ping/ping_test.go +++ b/net/ping/ping_test.go @@ -1,350 +1,350 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package ping - -import ( - "context" - "errors" - "fmt" - "net" - "testing" - "time" - - "golang.org/x/net/icmp" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" - "tailscale.com/tstest" - "tailscale.com/util/mak" -) - -var ( - localhost = &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)} -) - -func TestPinger(t *testing.T) { - clock := &tstest.Clock{} - - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - - p, closeP := mockPinger(t, clock) - defer closeP() - - bodyData := []byte("data goes here") - - // Start a ping in the background - r := make(chan time.Duration, 1) - go func() { - dur, err := p.Send(ctx, localhost, bodyData) - if err != nil { - t.Errorf("p.Send: %v", err) - r <- 0 - } else { - r <- dur - } - }() - - p.waitOutstanding(t, ctx, 1) - - // Fake a response from ourself - fakeResponse := mustMarshal(t, &icmp.Message{ - Type: ipv4.ICMPTypeEchoReply, - Code: ipv4.ICMPTypeEchoReply.Protocol(), - Body: &icmp.Echo{ - ID: 1234, - Seq: 1, - Data: bodyData, - }, - }) - - const fakeDuration = 100 * time.Millisecond - p.handleResponse(fakeResponse, clock.Now().Add(fakeDuration), v4Type) - - select { - case dur := <-r: - want := fakeDuration - if dur != want { - t.Errorf("wanted ping response time = %d; got %d", want, dur) - } - case <-ctx.Done(): - t.Fatal("did not get response by timeout") - } -} - -func TestV6Pinger(t *testing.T) { - if c, err := net.ListenPacket("udp6", "::1"); err != nil { - // skip test if we can't use IPv6. - t.Skipf("IPv6 not supported: %s", err) - } else { - c.Close() - } - - clock := &tstest.Clock{} - - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - - p, closeP := mockPinger(t, clock) - defer closeP() - - bodyData := []byte("data goes here") - - // Start a ping in the background - r := make(chan time.Duration, 1) - go func() { - dur, err := p.Send(ctx, &net.IPAddr{IP: net.ParseIP("::")}, bodyData) - if err != nil { - t.Errorf("p.Send: %v", err) - r <- 0 - } else { - r <- dur - } - }() - - p.waitOutstanding(t, ctx, 1) - - // Fake a response from ourself - fakeResponse := mustMarshal(t, &icmp.Message{ - Type: ipv6.ICMPTypeEchoReply, - Code: ipv6.ICMPTypeEchoReply.Protocol(), - Body: &icmp.Echo{ - ID: 1234, - Seq: 1, - Data: bodyData, - }, - }) - - const fakeDuration = 100 * time.Millisecond - p.handleResponse(fakeResponse, clock.Now().Add(fakeDuration), v6Type) - - select { - case dur := <-r: - want := fakeDuration - if dur != want { - t.Errorf("wanted ping response time = %d; got %d", want, dur) - } - case <-ctx.Done(): - t.Fatal("did not get response by timeout") - } -} - -func TestPingerTimeout(t *testing.T) { - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - - clock := &tstest.Clock{} - p, closeP := mockPinger(t, clock) - defer closeP() - - // Send a ping in the background - r := make(chan error, 1) - go func() { - _, err := p.Send(ctx, localhost, []byte("data goes here")) - r <- err - }() - - // Wait until we're blocking - p.waitOutstanding(t, ctx, 1) - - // Close everything down - p.cleanupOutstanding() - - // Should have got an error from the ping - err := <-r - if !errors.Is(err, net.ErrClosed) { - t.Errorf("wanted errors.Is(err, net.ErrClosed); got=%v", err) - } -} - -func TestPingerMismatch(t *testing.T) { - clock := &tstest.Clock{} - - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, 1*time.Second) // intentionally short - defer cancel() - - p, closeP := mockPinger(t, clock) - defer closeP() - - bodyData := []byte("data goes here") - - // Start a ping in the background - r := make(chan time.Duration, 1) - go func() { - dur, err := p.Send(ctx, localhost, bodyData) - if err != nil && !errors.Is(err, context.DeadlineExceeded) { - t.Errorf("p.Send: %v", err) - r <- 0 - } else { - r <- dur - } - }() - - p.waitOutstanding(t, ctx, 1) - - // "Receive" a bunch of intentionally malformed packets that should not - // result in the Send call above returning - badPackets := []struct { - name string - pkt *icmp.Message - }{ - { - name: "wrong type", - pkt: &icmp.Message{ - Type: ipv4.ICMPTypeDestinationUnreachable, - Code: 0, - Body: &icmp.DstUnreach{}, - }, - }, - { - name: "wrong id", - pkt: &icmp.Message{ - Type: ipv4.ICMPTypeEchoReply, - Code: 0, - Body: &icmp.Echo{ - ID: 9999, - Seq: 1, - Data: bodyData, - }, - }, - }, - { - name: "wrong seq", - pkt: &icmp.Message{ - Type: ipv4.ICMPTypeEchoReply, - Code: 0, - Body: &icmp.Echo{ - ID: 1234, - Seq: 5, - Data: bodyData, - }, - }, - }, - { - name: "bad body", - pkt: &icmp.Message{ - Type: ipv4.ICMPTypeEchoReply, - Code: 0, - Body: &icmp.Echo{ - ID: 1234, - Seq: 1, - - // Intentionally missing first byte - Data: bodyData[1:], - }, - }, - }, - } - - const fakeDuration = 100 * time.Millisecond - tm := clock.Now().Add(fakeDuration) - - for _, tt := range badPackets { - fakeResponse := mustMarshal(t, tt.pkt) - p.handleResponse(fakeResponse, tm, v4Type) - } - - // Also "receive" a packet that does not unmarshal as an ICMP packet - p.handleResponse([]byte("foo"), tm, v4Type) - - select { - case <-r: - t.Fatal("wanted timeout") - case <-ctx.Done(): - t.Logf("test correctly timed out") - } -} - -// udpingPacketConn will convert potentially ICMP destination addrs to UDP -// destination addrs in WriteTo so that a test that is intending to send ICMP -// traffic will instead send UDP traffic, without the higher level Pinger being -// aware of this difference. -type udpingPacketConn struct { - net.PacketConn - // destPort will be configured by the test to be the peer expected to respond to a ping. - destPort uint16 -} - -func (u *udpingPacketConn) WriteTo(body []byte, dest net.Addr) (int, error) { - switch d := dest.(type) { - case *net.IPAddr: - udpAddr := &net.UDPAddr{ - IP: d.IP, - Port: int(u.destPort), - Zone: d.Zone, - } - return u.PacketConn.WriteTo(body, udpAddr) - } - return 0, fmt.Errorf("unimplemented udpingPacketConn for %T", dest) -} - -func mockPinger(t *testing.T, clock *tstest.Clock) (*Pinger, func()) { - p := New(context.Background(), t.Logf, nil) - p.timeNow = clock.Now - p.Verbose = true - p.id = 1234 - - // In tests, we use UDP so that we can test without being root; this - // doesn't matter because we mock out the ICMP reply below to be a real - // ICMP echo reply packet. - conn4, err := net.ListenPacket("udp4", "127.0.0.1:0") - if err != nil { - t.Fatalf("net.ListenPacket: %v", err) - } - - conn6, err := net.ListenPacket("udp6", "[::]:0") - if err != nil { - t.Fatalf("net.ListenPacket: %v", err) - } - - conn4 = &udpingPacketConn{ - destPort: 12345, - PacketConn: conn4, - } - conn6 = &udpingPacketConn{ - PacketConn: conn6, - destPort: 12345, - } - - mak.Set(&p.conns, v4Type, conn4) - mak.Set(&p.conns, v6Type, conn6) - done := func() { - if err := p.Close(); err != nil { - t.Errorf("error on close: %v", err) - } - } - return p, done -} - -func mustMarshal(t *testing.T, m *icmp.Message) []byte { - t.Helper() - - b, err := m.Marshal(nil) - if err != nil { - t.Fatal(err) - } - return b -} - -func (p *Pinger) waitOutstanding(t *testing.T, ctx context.Context, count int) { - // This is a bit janky, but... we busy-loop to wait for the Send call - // to write to our map so we know that a response will be handled. - var haveMapEntry bool - for !haveMapEntry { - time.Sleep(10 * time.Millisecond) - select { - case <-ctx.Done(): - t.Error("no entry in ping map before timeout") - return - default: - } - - p.mu.Lock() - haveMapEntry = len(p.pings) == count - p.mu.Unlock() - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ping + +import ( + "context" + "errors" + "fmt" + "net" + "testing" + "time" + + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "tailscale.com/tstest" + "tailscale.com/util/mak" +) + +var ( + localhost = &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)} +) + +func TestPinger(t *testing.T) { + clock := &tstest.Clock{} + + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + p, closeP := mockPinger(t, clock) + defer closeP() + + bodyData := []byte("data goes here") + + // Start a ping in the background + r := make(chan time.Duration, 1) + go func() { + dur, err := p.Send(ctx, localhost, bodyData) + if err != nil { + t.Errorf("p.Send: %v", err) + r <- 0 + } else { + r <- dur + } + }() + + p.waitOutstanding(t, ctx, 1) + + // Fake a response from ourself + fakeResponse := mustMarshal(t, &icmp.Message{ + Type: ipv4.ICMPTypeEchoReply, + Code: ipv4.ICMPTypeEchoReply.Protocol(), + Body: &icmp.Echo{ + ID: 1234, + Seq: 1, + Data: bodyData, + }, + }) + + const fakeDuration = 100 * time.Millisecond + p.handleResponse(fakeResponse, clock.Now().Add(fakeDuration), v4Type) + + select { + case dur := <-r: + want := fakeDuration + if dur != want { + t.Errorf("wanted ping response time = %d; got %d", want, dur) + } + case <-ctx.Done(): + t.Fatal("did not get response by timeout") + } +} + +func TestV6Pinger(t *testing.T) { + if c, err := net.ListenPacket("udp6", "::1"); err != nil { + // skip test if we can't use IPv6. + t.Skipf("IPv6 not supported: %s", err) + } else { + c.Close() + } + + clock := &tstest.Clock{} + + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + p, closeP := mockPinger(t, clock) + defer closeP() + + bodyData := []byte("data goes here") + + // Start a ping in the background + r := make(chan time.Duration, 1) + go func() { + dur, err := p.Send(ctx, &net.IPAddr{IP: net.ParseIP("::")}, bodyData) + if err != nil { + t.Errorf("p.Send: %v", err) + r <- 0 + } else { + r <- dur + } + }() + + p.waitOutstanding(t, ctx, 1) + + // Fake a response from ourself + fakeResponse := mustMarshal(t, &icmp.Message{ + Type: ipv6.ICMPTypeEchoReply, + Code: ipv6.ICMPTypeEchoReply.Protocol(), + Body: &icmp.Echo{ + ID: 1234, + Seq: 1, + Data: bodyData, + }, + }) + + const fakeDuration = 100 * time.Millisecond + p.handleResponse(fakeResponse, clock.Now().Add(fakeDuration), v6Type) + + select { + case dur := <-r: + want := fakeDuration + if dur != want { + t.Errorf("wanted ping response time = %d; got %d", want, dur) + } + case <-ctx.Done(): + t.Fatal("did not get response by timeout") + } +} + +func TestPingerTimeout(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + clock := &tstest.Clock{} + p, closeP := mockPinger(t, clock) + defer closeP() + + // Send a ping in the background + r := make(chan error, 1) + go func() { + _, err := p.Send(ctx, localhost, []byte("data goes here")) + r <- err + }() + + // Wait until we're blocking + p.waitOutstanding(t, ctx, 1) + + // Close everything down + p.cleanupOutstanding() + + // Should have got an error from the ping + err := <-r + if !errors.Is(err, net.ErrClosed) { + t.Errorf("wanted errors.Is(err, net.ErrClosed); got=%v", err) + } +} + +func TestPingerMismatch(t *testing.T) { + clock := &tstest.Clock{} + + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 1*time.Second) // intentionally short + defer cancel() + + p, closeP := mockPinger(t, clock) + defer closeP() + + bodyData := []byte("data goes here") + + // Start a ping in the background + r := make(chan time.Duration, 1) + go func() { + dur, err := p.Send(ctx, localhost, bodyData) + if err != nil && !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("p.Send: %v", err) + r <- 0 + } else { + r <- dur + } + }() + + p.waitOutstanding(t, ctx, 1) + + // "Receive" a bunch of intentionally malformed packets that should not + // result in the Send call above returning + badPackets := []struct { + name string + pkt *icmp.Message + }{ + { + name: "wrong type", + pkt: &icmp.Message{ + Type: ipv4.ICMPTypeDestinationUnreachable, + Code: 0, + Body: &icmp.DstUnreach{}, + }, + }, + { + name: "wrong id", + pkt: &icmp.Message{ + Type: ipv4.ICMPTypeEchoReply, + Code: 0, + Body: &icmp.Echo{ + ID: 9999, + Seq: 1, + Data: bodyData, + }, + }, + }, + { + name: "wrong seq", + pkt: &icmp.Message{ + Type: ipv4.ICMPTypeEchoReply, + Code: 0, + Body: &icmp.Echo{ + ID: 1234, + Seq: 5, + Data: bodyData, + }, + }, + }, + { + name: "bad body", + pkt: &icmp.Message{ + Type: ipv4.ICMPTypeEchoReply, + Code: 0, + Body: &icmp.Echo{ + ID: 1234, + Seq: 1, + + // Intentionally missing first byte + Data: bodyData[1:], + }, + }, + }, + } + + const fakeDuration = 100 * time.Millisecond + tm := clock.Now().Add(fakeDuration) + + for _, tt := range badPackets { + fakeResponse := mustMarshal(t, tt.pkt) + p.handleResponse(fakeResponse, tm, v4Type) + } + + // Also "receive" a packet that does not unmarshal as an ICMP packet + p.handleResponse([]byte("foo"), tm, v4Type) + + select { + case <-r: + t.Fatal("wanted timeout") + case <-ctx.Done(): + t.Logf("test correctly timed out") + } +} + +// udpingPacketConn will convert potentially ICMP destination addrs to UDP +// destination addrs in WriteTo so that a test that is intending to send ICMP +// traffic will instead send UDP traffic, without the higher level Pinger being +// aware of this difference. +type udpingPacketConn struct { + net.PacketConn + // destPort will be configured by the test to be the peer expected to respond to a ping. + destPort uint16 +} + +func (u *udpingPacketConn) WriteTo(body []byte, dest net.Addr) (int, error) { + switch d := dest.(type) { + case *net.IPAddr: + udpAddr := &net.UDPAddr{ + IP: d.IP, + Port: int(u.destPort), + Zone: d.Zone, + } + return u.PacketConn.WriteTo(body, udpAddr) + } + return 0, fmt.Errorf("unimplemented udpingPacketConn for %T", dest) +} + +func mockPinger(t *testing.T, clock *tstest.Clock) (*Pinger, func()) { + p := New(context.Background(), t.Logf, nil) + p.timeNow = clock.Now + p.Verbose = true + p.id = 1234 + + // In tests, we use UDP so that we can test without being root; this + // doesn't matter because we mock out the ICMP reply below to be a real + // ICMP echo reply packet. + conn4, err := net.ListenPacket("udp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.ListenPacket: %v", err) + } + + conn6, err := net.ListenPacket("udp6", "[::]:0") + if err != nil { + t.Fatalf("net.ListenPacket: %v", err) + } + + conn4 = &udpingPacketConn{ + destPort: 12345, + PacketConn: conn4, + } + conn6 = &udpingPacketConn{ + PacketConn: conn6, + destPort: 12345, + } + + mak.Set(&p.conns, v4Type, conn4) + mak.Set(&p.conns, v6Type, conn6) + done := func() { + if err := p.Close(); err != nil { + t.Errorf("error on close: %v", err) + } + } + return p, done +} + +func mustMarshal(t *testing.T, m *icmp.Message) []byte { + t.Helper() + + b, err := m.Marshal(nil) + if err != nil { + t.Fatal(err) + } + return b +} + +func (p *Pinger) waitOutstanding(t *testing.T, ctx context.Context, count int) { + // This is a bit janky, but... we busy-loop to wait for the Send call + // to write to our map so we know that a response will be handled. + var haveMapEntry bool + for !haveMapEntry { + time.Sleep(10 * time.Millisecond) + select { + case <-ctx.Done(): + t.Error("no entry in ping map before timeout") + return + default: + } + + p.mu.Lock() + haveMapEntry = len(p.pings) == count + p.mu.Unlock() + } +} diff --git a/net/portmapper/pcp_test.go b/net/portmapper/pcp_test.go index 3dece7236..8f8eef3ef 100644 --- a/net/portmapper/pcp_test.go +++ b/net/portmapper/pcp_test.go @@ -1,62 +1,62 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package portmapper - -import ( - "encoding/binary" - "net/netip" - "testing" - - "tailscale.com/net/netaddr" -) - -var examplePCPMapResponse = []byte{2, 129, 0, 0, 0, 0, 28, 32, 0, 2, 155, 237, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 129, 112, 9, 24, 241, 208, 251, 45, 157, 76, 10, 188, 17, 0, 0, 0, 4, 210, 4, 210, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 135, 180, 175, 246} - -func TestParsePCPMapResponse(t *testing.T) { - mapping, err := parsePCPMapResponse(examplePCPMapResponse) - if err != nil { - t.Fatalf("failed to parse PCP Map Response: %v", err) - } - if mapping == nil { - t.Fatalf("got nil mapping when expected non-nil") - } - expectedAddr := netip.MustParseAddrPort("135.180.175.246:1234") - if mapping.external != expectedAddr { - t.Errorf("mismatched external address, got: %v, want: %v", mapping.external, expectedAddr) - } -} - -const ( - serverResponseBit = 1 << 7 - fakeLifetimeSec = 1<<31 - 1 -) - -func buildPCPDiscoResponse(req []byte) []byte { - out := make([]byte, 24) - out[0] = pcpVersion - out[1] = req[1] | serverResponseBit - out[3] = 0 - // Do not put an epoch time in 8:12, when we start using it, tests that use it should fail. - return out -} - -func buildPCPMapResponse(req []byte) []byte { - out := make([]byte, 24+36) - out[0] = pcpVersion - out[1] = req[1] | serverResponseBit - out[3] = 0 - binary.BigEndian.PutUint32(out[4:8], 1<<30) - // Do not put an epoch time in 8:12, when we start using it, tests that use it should fail. - mapResp := out[24:] - mapReq := req[24:] - // copy nonce, protocol and internal port - copy(mapResp[:13], mapReq[:13]) - copy(mapResp[16:18], mapReq[16:18]) - // assign external port - binary.BigEndian.PutUint16(mapResp[18:20], 4242) - assignedIP := netaddr.IPv4(127, 0, 0, 1) - assignedIP16 := assignedIP.As16() - copy(mapResp[20:36], assignedIP16[:]) - return out -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package portmapper + +import ( + "encoding/binary" + "net/netip" + "testing" + + "tailscale.com/net/netaddr" +) + +var examplePCPMapResponse = []byte{2, 129, 0, 0, 0, 0, 28, 32, 0, 2, 155, 237, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 129, 112, 9, 24, 241, 208, 251, 45, 157, 76, 10, 188, 17, 0, 0, 0, 4, 210, 4, 210, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 135, 180, 175, 246} + +func TestParsePCPMapResponse(t *testing.T) { + mapping, err := parsePCPMapResponse(examplePCPMapResponse) + if err != nil { + t.Fatalf("failed to parse PCP Map Response: %v", err) + } + if mapping == nil { + t.Fatalf("got nil mapping when expected non-nil") + } + expectedAddr := netip.MustParseAddrPort("135.180.175.246:1234") + if mapping.external != expectedAddr { + t.Errorf("mismatched external address, got: %v, want: %v", mapping.external, expectedAddr) + } +} + +const ( + serverResponseBit = 1 << 7 + fakeLifetimeSec = 1<<31 - 1 +) + +func buildPCPDiscoResponse(req []byte) []byte { + out := make([]byte, 24) + out[0] = pcpVersion + out[1] = req[1] | serverResponseBit + out[3] = 0 + // Do not put an epoch time in 8:12, when we start using it, tests that use it should fail. + return out +} + +func buildPCPMapResponse(req []byte) []byte { + out := make([]byte, 24+36) + out[0] = pcpVersion + out[1] = req[1] | serverResponseBit + out[3] = 0 + binary.BigEndian.PutUint32(out[4:8], 1<<30) + // Do not put an epoch time in 8:12, when we start using it, tests that use it should fail. + mapResp := out[24:] + mapReq := req[24:] + // copy nonce, protocol and internal port + copy(mapResp[:13], mapReq[:13]) + copy(mapResp[16:18], mapReq[16:18]) + // assign external port + binary.BigEndian.PutUint16(mapResp[18:20], 4242) + assignedIP := netaddr.IPv4(127, 0, 0, 1) + assignedIP16 := assignedIP.As16() + copy(mapResp[20:36], assignedIP16[:]) + return out +} diff --git a/net/proxymux/mux.go b/net/proxymux/mux.go index 12c3107de..ff5aaff3b 100644 --- a/net/proxymux/mux.go +++ b/net/proxymux/mux.go @@ -1,144 +1,144 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package proxymux splits a net.Listener in two, routing SOCKS5 -// connections to one and HTTP requests to the other. -// -// It allows for hosting both a SOCKS5 proxy and an HTTP proxy on the -// same listener. -package proxymux - -import ( - "io" - "net" - "sync" - "time" -) - -// SplitSOCKSAndHTTP accepts connections on ln and passes connections -// through to either socksListener or httpListener, depending the -// first byte sent by the client. -func SplitSOCKSAndHTTP(ln net.Listener) (socksListener, httpListener net.Listener) { - sl := &listener{ - addr: ln.Addr(), - c: make(chan net.Conn), - closed: make(chan struct{}), - } - hl := &listener{ - addr: ln.Addr(), - c: make(chan net.Conn), - closed: make(chan struct{}), - } - - go splitSOCKSAndHTTPListener(ln, sl, hl) - - return sl, hl -} - -func splitSOCKSAndHTTPListener(ln net.Listener, sl, hl *listener) { - for { - conn, err := ln.Accept() - if err != nil { - sl.Close() - hl.Close() - return - } - go routeConn(conn, sl, hl) - } -} - -func routeConn(c net.Conn, socksListener, httpListener *listener) { - if err := c.SetReadDeadline(time.Now().Add(15 * time.Second)); err != nil { - c.Close() - return - } - - var b [1]byte - if _, err := io.ReadFull(c, b[:]); err != nil { - c.Close() - return - } - - if err := c.SetReadDeadline(time.Time{}); err != nil { - c.Close() - return - } - - conn := &connWithOneByte{ - Conn: c, - b: b[0], - } - - // First byte of a SOCKS5 session is a version byte set to 5. - var ln *listener - if b[0] == 5 { - ln = socksListener - } else { - ln = httpListener - } - select { - case ln.c <- conn: - case <-ln.closed: - c.Close() - } -} - -type listener struct { - addr net.Addr - c chan net.Conn - mu sync.Mutex // serializes close() on closed. It's okay to receive on closed without locking. - closed chan struct{} -} - -func (ln *listener) Accept() (net.Conn, error) { - // Once closed, reliably stay closed, don't race with attempts at - // further connections. - select { - case <-ln.closed: - return nil, net.ErrClosed - default: - } - select { - case ret := <-ln.c: - return ret, nil - case <-ln.closed: - return nil, net.ErrClosed - } -} - -func (ln *listener) Close() error { - ln.mu.Lock() - defer ln.mu.Unlock() - select { - case <-ln.closed: - // Already closed - default: - close(ln.closed) - } - return nil -} - -func (ln *listener) Addr() net.Addr { - return ln.addr -} - -// connWithOneByte is a net.Conn that returns b for the first read -// request, then forwards everything else to Conn. -type connWithOneByte struct { - net.Conn - - b byte - bRead bool -} - -func (c *connWithOneByte) Read(bs []byte) (int, error) { - if c.bRead { - return c.Conn.Read(bs) - } - if len(bs) == 0 { - return 0, nil - } - c.bRead = true - bs[0] = c.b - return 1, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package proxymux splits a net.Listener in two, routing SOCKS5 +// connections to one and HTTP requests to the other. +// +// It allows for hosting both a SOCKS5 proxy and an HTTP proxy on the +// same listener. +package proxymux + +import ( + "io" + "net" + "sync" + "time" +) + +// SplitSOCKSAndHTTP accepts connections on ln and passes connections +// through to either socksListener or httpListener, depending the +// first byte sent by the client. +func SplitSOCKSAndHTTP(ln net.Listener) (socksListener, httpListener net.Listener) { + sl := &listener{ + addr: ln.Addr(), + c: make(chan net.Conn), + closed: make(chan struct{}), + } + hl := &listener{ + addr: ln.Addr(), + c: make(chan net.Conn), + closed: make(chan struct{}), + } + + go splitSOCKSAndHTTPListener(ln, sl, hl) + + return sl, hl +} + +func splitSOCKSAndHTTPListener(ln net.Listener, sl, hl *listener) { + for { + conn, err := ln.Accept() + if err != nil { + sl.Close() + hl.Close() + return + } + go routeConn(conn, sl, hl) + } +} + +func routeConn(c net.Conn, socksListener, httpListener *listener) { + if err := c.SetReadDeadline(time.Now().Add(15 * time.Second)); err != nil { + c.Close() + return + } + + var b [1]byte + if _, err := io.ReadFull(c, b[:]); err != nil { + c.Close() + return + } + + if err := c.SetReadDeadline(time.Time{}); err != nil { + c.Close() + return + } + + conn := &connWithOneByte{ + Conn: c, + b: b[0], + } + + // First byte of a SOCKS5 session is a version byte set to 5. + var ln *listener + if b[0] == 5 { + ln = socksListener + } else { + ln = httpListener + } + select { + case ln.c <- conn: + case <-ln.closed: + c.Close() + } +} + +type listener struct { + addr net.Addr + c chan net.Conn + mu sync.Mutex // serializes close() on closed. It's okay to receive on closed without locking. + closed chan struct{} +} + +func (ln *listener) Accept() (net.Conn, error) { + // Once closed, reliably stay closed, don't race with attempts at + // further connections. + select { + case <-ln.closed: + return nil, net.ErrClosed + default: + } + select { + case ret := <-ln.c: + return ret, nil + case <-ln.closed: + return nil, net.ErrClosed + } +} + +func (ln *listener) Close() error { + ln.mu.Lock() + defer ln.mu.Unlock() + select { + case <-ln.closed: + // Already closed + default: + close(ln.closed) + } + return nil +} + +func (ln *listener) Addr() net.Addr { + return ln.addr +} + +// connWithOneByte is a net.Conn that returns b for the first read +// request, then forwards everything else to Conn. +type connWithOneByte struct { + net.Conn + + b byte + bRead bool +} + +func (c *connWithOneByte) Read(bs []byte) (int, error) { + if c.bRead { + return c.Conn.Read(bs) + } + if len(bs) == 0 { + return 0, nil + } + c.bRead = true + bs[0] = c.b + return 1, nil +} diff --git a/net/routetable/routetable_darwin.go b/net/routetable/routetable_darwin.go index 7de80a662..7f525ae32 100644 --- a/net/routetable/routetable_darwin.go +++ b/net/routetable/routetable_darwin.go @@ -1,36 +1,36 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build darwin - -package routetable - -import "golang.org/x/sys/unix" - -const ( - ribType = unix.NET_RT_DUMP2 - parseType = unix.NET_RT_IFLIST2 - rmExpectedType = unix.RTM_GET2 - - // Skip routes that were cloned from a parent - skipFlags = unix.RTF_WASCLONED -) - -var flags = map[int]string{ - unix.RTF_BLACKHOLE: "blackhole", - unix.RTF_BROADCAST: "broadcast", - unix.RTF_GATEWAY: "gateway", - unix.RTF_GLOBAL: "global", - unix.RTF_HOST: "host", - unix.RTF_IFSCOPE: "ifscope", - unix.RTF_LOCAL: "local", - unix.RTF_MULTICAST: "multicast", - unix.RTF_REJECT: "reject", - unix.RTF_ROUTER: "router", - unix.RTF_STATIC: "static", - unix.RTF_UP: "up", - // More obscure flags, just to have full coverage. - unix.RTF_LLINFO: "{RTF_LLINFO}", - unix.RTF_PRCLONING: "{RTF_PRCLONING}", - unix.RTF_CLONING: "{RTF_CLONING}", -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build darwin + +package routetable + +import "golang.org/x/sys/unix" + +const ( + ribType = unix.NET_RT_DUMP2 + parseType = unix.NET_RT_IFLIST2 + rmExpectedType = unix.RTM_GET2 + + // Skip routes that were cloned from a parent + skipFlags = unix.RTF_WASCLONED +) + +var flags = map[int]string{ + unix.RTF_BLACKHOLE: "blackhole", + unix.RTF_BROADCAST: "broadcast", + unix.RTF_GATEWAY: "gateway", + unix.RTF_GLOBAL: "global", + unix.RTF_HOST: "host", + unix.RTF_IFSCOPE: "ifscope", + unix.RTF_LOCAL: "local", + unix.RTF_MULTICAST: "multicast", + unix.RTF_REJECT: "reject", + unix.RTF_ROUTER: "router", + unix.RTF_STATIC: "static", + unix.RTF_UP: "up", + // More obscure flags, just to have full coverage. + unix.RTF_LLINFO: "{RTF_LLINFO}", + unix.RTF_PRCLONING: "{RTF_PRCLONING}", + unix.RTF_CLONING: "{RTF_CLONING}", +} diff --git a/net/routetable/routetable_freebsd.go b/net/routetable/routetable_freebsd.go index aa4e03c41..8e57a3302 100644 --- a/net/routetable/routetable_freebsd.go +++ b/net/routetable/routetable_freebsd.go @@ -1,28 +1,28 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build freebsd - -package routetable - -import "golang.org/x/sys/unix" - -const ( - ribType = unix.NET_RT_DUMP - parseType = unix.NET_RT_IFLIST - rmExpectedType = unix.RTM_GET - - // Nothing to skip - skipFlags = 0 -) - -var flags = map[int]string{ - unix.RTF_BLACKHOLE: "blackhole", - unix.RTF_BROADCAST: "broadcast", - unix.RTF_GATEWAY: "gateway", - unix.RTF_HOST: "host", - unix.RTF_MULTICAST: "multicast", - unix.RTF_REJECT: "reject", - unix.RTF_STATIC: "static", - unix.RTF_UP: "up", -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build freebsd + +package routetable + +import "golang.org/x/sys/unix" + +const ( + ribType = unix.NET_RT_DUMP + parseType = unix.NET_RT_IFLIST + rmExpectedType = unix.RTM_GET + + // Nothing to skip + skipFlags = 0 +) + +var flags = map[int]string{ + unix.RTF_BLACKHOLE: "blackhole", + unix.RTF_BROADCAST: "broadcast", + unix.RTF_GATEWAY: "gateway", + unix.RTF_HOST: "host", + unix.RTF_MULTICAST: "multicast", + unix.RTF_REJECT: "reject", + unix.RTF_STATIC: "static", + unix.RTF_UP: "up", +} diff --git a/net/routetable/routetable_other.go b/net/routetable/routetable_other.go index 521fe1911..35c83e374 100644 --- a/net/routetable/routetable_other.go +++ b/net/routetable/routetable_other.go @@ -1,17 +1,17 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux && !darwin && !freebsd - -package routetable - -import ( - "errors" - "runtime" -) - -var errUnsupported = errors.New("cannot get route table on platform " + runtime.GOOS) - -func Get(max int) ([]RouteEntry, error) { - return nil, errUnsupported -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux && !darwin && !freebsd + +package routetable + +import ( + "errors" + "runtime" +) + +var errUnsupported = errors.New("cannot get route table on platform " + runtime.GOOS) + +func Get(max int) ([]RouteEntry, error) { + return nil, errUnsupported +} diff --git a/net/sockstats/sockstats.go b/net/sockstats/sockstats.go index fb524a5c5..715c1ee06 100644 --- a/net/sockstats/sockstats.go +++ b/net/sockstats/sockstats.go @@ -1,121 +1,121 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package sockstats collects statistics about network sockets used by -// the Tailscale client. The context where sockets are used must be -// instrumented with the WithSockStats() function. -// -// Only available on POSIX platforms when built with Tailscale's fork of Go. -package sockstats - -import ( - "context" - - "tailscale.com/net/netmon" - "tailscale.com/types/logger" -) - -// SockStats contains statistics for sockets instrumented with the -// WithSockStats() function -type SockStats struct { - Stats map[Label]SockStat - CurrentInterfaceCellular bool -} - -// SockStat contains the sent and received bytes for a socket instrumented with -// the WithSockStats() function. -type SockStat struct { - TxBytes uint64 - RxBytes uint64 -} - -// Label is an identifier for a socket that stats are collected for. A finite -// set of values that may be used to label a socket to encourage grouping and -// to make storage more efficient. -type Label uint8 - -//go:generate go run golang.org/x/tools/cmd/stringer -type Label -trimprefix Label - -// Labels are named after the package and function/struct that uses the socket. -// Values may be persisted and thus existing entries should not be re-numbered. -const ( - LabelControlClientAuto Label = 0 // control/controlclient/auto.go - LabelControlClientDialer Label = 1 // control/controlhttp/client.go - LabelDERPHTTPClient Label = 2 // derp/derphttp/derphttp_client.go - LabelLogtailLogger Label = 3 // logtail/logtail.go - LabelDNSForwarderDoH Label = 4 // net/dns/resolver/forwarder.go - LabelDNSForwarderUDP Label = 5 // net/dns/resolver/forwarder.go - LabelNetcheckClient Label = 6 // net/netcheck/netcheck.go - LabelPortmapperClient Label = 7 // net/portmapper/portmapper.go - LabelMagicsockConnUDP4 Label = 8 // wgengine/magicsock/magicsock.go - LabelMagicsockConnUDP6 Label = 9 // wgengine/magicsock/magicsock.go - LabelNetlogLogger Label = 10 // wgengine/netlog/logger.go - LabelSockstatlogLogger Label = 11 // log/sockstatlog/logger.go - LabelDNSForwarderTCP Label = 12 // net/dns/resolver/forwarder.go -) - -// WithSockStats instruments a context so that sockets created with it will -// have their statistics collected. -func WithSockStats(ctx context.Context, label Label, logf logger.Logf) context.Context { - return withSockStats(ctx, label, logf) -} - -// Get returns the current socket statistics. -func Get() *SockStats { - return get() -} - -// InterfaceSockStats contains statistics for sockets instrumented with the -// WithSockStats() function, broken down by interface. The statistics may be a -// subset of the total if interfaces were added after the instrumented socket -// was created. -type InterfaceSockStats struct { - Stats map[Label]InterfaceSockStat - Interfaces []string -} - -// InterfaceSockStat contains the per-interface sent and received bytes for a -// socket instrumented with the WithSockStats() function. -type InterfaceSockStat struct { - TxBytesByInterface map[string]uint64 - RxBytesByInterface map[string]uint64 -} - -// GetWithInterfaces is a variant of Get that returns the current socket -// statistics broken down by interface. It is slightly more expensive than Get. -func GetInterfaces() *InterfaceSockStats { - return getInterfaces() -} - -// ValidationSockStats contains external validation numbers for sockets -// instrumented with WithSockStats. It may be a subset of the all sockets, -// depending on what externa measurement mechanisms the platform supports. -type ValidationSockStats struct { - Stats map[Label]ValidationSockStat -} - -// ValidationSockStat contains the validation bytes for a socket instrumented -// with WithSockStats. -type ValidationSockStat struct { - TxBytes uint64 - RxBytes uint64 -} - -// GetValidation is a variant of Get that returns external validation numbers -// for stats. It is more expensive than Get and should be used in debug -// interfaces only. -func GetValidation() *ValidationSockStats { - return getValidation() -} - -// SetNetMon configures the sockstats package to monitor the active -// interface, so that per-interface stats can be collected. -func SetNetMon(netMon *netmon.Monitor) { - setNetMon(netMon) -} - -// DebugInfo returns a string containing debug information about the tracked -// statistics. -func DebugInfo() string { - return debugInfo() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package sockstats collects statistics about network sockets used by +// the Tailscale client. The context where sockets are used must be +// instrumented with the WithSockStats() function. +// +// Only available on POSIX platforms when built with Tailscale's fork of Go. +package sockstats + +import ( + "context" + + "tailscale.com/net/netmon" + "tailscale.com/types/logger" +) + +// SockStats contains statistics for sockets instrumented with the +// WithSockStats() function +type SockStats struct { + Stats map[Label]SockStat + CurrentInterfaceCellular bool +} + +// SockStat contains the sent and received bytes for a socket instrumented with +// the WithSockStats() function. +type SockStat struct { + TxBytes uint64 + RxBytes uint64 +} + +// Label is an identifier for a socket that stats are collected for. A finite +// set of values that may be used to label a socket to encourage grouping and +// to make storage more efficient. +type Label uint8 + +//go:generate go run golang.org/x/tools/cmd/stringer -type Label -trimprefix Label + +// Labels are named after the package and function/struct that uses the socket. +// Values may be persisted and thus existing entries should not be re-numbered. +const ( + LabelControlClientAuto Label = 0 // control/controlclient/auto.go + LabelControlClientDialer Label = 1 // control/controlhttp/client.go + LabelDERPHTTPClient Label = 2 // derp/derphttp/derphttp_client.go + LabelLogtailLogger Label = 3 // logtail/logtail.go + LabelDNSForwarderDoH Label = 4 // net/dns/resolver/forwarder.go + LabelDNSForwarderUDP Label = 5 // net/dns/resolver/forwarder.go + LabelNetcheckClient Label = 6 // net/netcheck/netcheck.go + LabelPortmapperClient Label = 7 // net/portmapper/portmapper.go + LabelMagicsockConnUDP4 Label = 8 // wgengine/magicsock/magicsock.go + LabelMagicsockConnUDP6 Label = 9 // wgengine/magicsock/magicsock.go + LabelNetlogLogger Label = 10 // wgengine/netlog/logger.go + LabelSockstatlogLogger Label = 11 // log/sockstatlog/logger.go + LabelDNSForwarderTCP Label = 12 // net/dns/resolver/forwarder.go +) + +// WithSockStats instruments a context so that sockets created with it will +// have their statistics collected. +func WithSockStats(ctx context.Context, label Label, logf logger.Logf) context.Context { + return withSockStats(ctx, label, logf) +} + +// Get returns the current socket statistics. +func Get() *SockStats { + return get() +} + +// InterfaceSockStats contains statistics for sockets instrumented with the +// WithSockStats() function, broken down by interface. The statistics may be a +// subset of the total if interfaces were added after the instrumented socket +// was created. +type InterfaceSockStats struct { + Stats map[Label]InterfaceSockStat + Interfaces []string +} + +// InterfaceSockStat contains the per-interface sent and received bytes for a +// socket instrumented with the WithSockStats() function. +type InterfaceSockStat struct { + TxBytesByInterface map[string]uint64 + RxBytesByInterface map[string]uint64 +} + +// GetWithInterfaces is a variant of Get that returns the current socket +// statistics broken down by interface. It is slightly more expensive than Get. +func GetInterfaces() *InterfaceSockStats { + return getInterfaces() +} + +// ValidationSockStats contains external validation numbers for sockets +// instrumented with WithSockStats. It may be a subset of the all sockets, +// depending on what externa measurement mechanisms the platform supports. +type ValidationSockStats struct { + Stats map[Label]ValidationSockStat +} + +// ValidationSockStat contains the validation bytes for a socket instrumented +// with WithSockStats. +type ValidationSockStat struct { + TxBytes uint64 + RxBytes uint64 +} + +// GetValidation is a variant of Get that returns external validation numbers +// for stats. It is more expensive than Get and should be used in debug +// interfaces only. +func GetValidation() *ValidationSockStats { + return getValidation() +} + +// SetNetMon configures the sockstats package to monitor the active +// interface, so that per-interface stats can be collected. +func SetNetMon(netMon *netmon.Monitor) { + setNetMon(netMon) +} + +// DebugInfo returns a string containing debug information about the tracked +// statistics. +func DebugInfo() string { + return debugInfo() +} diff --git a/net/sockstats/sockstats_noop.go b/net/sockstats/sockstats_noop.go index 797fdc42b..96723111a 100644 --- a/net/sockstats/sockstats_noop.go +++ b/net/sockstats/sockstats_noop.go @@ -1,38 +1,38 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !tailscale_go || !(darwin || ios || android || ts_enable_sockstats) - -package sockstats - -import ( - "context" - - "tailscale.com/net/netmon" - "tailscale.com/types/logger" -) - -const IsAvailable = false - -func withSockStats(ctx context.Context, label Label, logf logger.Logf) context.Context { - return ctx -} - -func get() *SockStats { - return nil -} - -func getInterfaces() *InterfaceSockStats { - return nil -} - -func getValidation() *ValidationSockStats { - return nil -} - -func setNetMon(netMon *netmon.Monitor) { -} - -func debugInfo() string { - return "" -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !tailscale_go || !(darwin || ios || android || ts_enable_sockstats) + +package sockstats + +import ( + "context" + + "tailscale.com/net/netmon" + "tailscale.com/types/logger" +) + +const IsAvailable = false + +func withSockStats(ctx context.Context, label Label, logf logger.Logf) context.Context { + return ctx +} + +func get() *SockStats { + return nil +} + +func getInterfaces() *InterfaceSockStats { + return nil +} + +func getValidation() *ValidationSockStats { + return nil +} + +func setNetMon(netMon *netmon.Monitor) { +} + +func debugInfo() string { + return "" +} diff --git a/net/sockstats/sockstats_tsgo_darwin.go b/net/sockstats/sockstats_tsgo_darwin.go index 4b03ed616..321d32e04 100644 --- a/net/sockstats/sockstats_tsgo_darwin.go +++ b/net/sockstats/sockstats_tsgo_darwin.go @@ -1,30 +1,30 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build tailscale_go && (darwin || ios) - -package sockstats - -import ( - "syscall" - - "golang.org/x/sys/unix" -) - -func init() { - tcpConnStats = darwinTcpConnStats -} - -func darwinTcpConnStats(c syscall.RawConn) (tx, rx uint64) { - c.Control(func(fd uintptr) { - if rawInfo, err := unix.GetsockoptTCPConnectionInfo( - int(fd), - unix.IPPROTO_TCP, - unix.TCP_CONNECTION_INFO, - ); err == nil { - tx = uint64(rawInfo.Txbytes) - rx = uint64(rawInfo.Rxbytes) - } - }) - return -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build tailscale_go && (darwin || ios) + +package sockstats + +import ( + "syscall" + + "golang.org/x/sys/unix" +) + +func init() { + tcpConnStats = darwinTcpConnStats +} + +func darwinTcpConnStats(c syscall.RawConn) (tx, rx uint64) { + c.Control(func(fd uintptr) { + if rawInfo, err := unix.GetsockoptTCPConnectionInfo( + int(fd), + unix.IPPROTO_TCP, + unix.TCP_CONNECTION_INFO, + ); err == nil { + tx = uint64(rawInfo.Txbytes) + rx = uint64(rawInfo.Rxbytes) + } + }) + return +} diff --git a/net/speedtest/speedtest.go b/net/speedtest/speedtest.go index 89639c12d..7ab0881cc 100644 --- a/net/speedtest/speedtest.go +++ b/net/speedtest/speedtest.go @@ -1,87 +1,87 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package speedtest contains both server and client code for -// running speedtests between tailscale nodes. -package speedtest - -import ( - "time" -) - -const ( - blockSize = 2 * 1024 * 1024 // size of the block of data to send - MinDuration = 5 * time.Second // minimum duration for a test - DefaultDuration = MinDuration // default duration for a test - MaxDuration = 30 * time.Second // maximum duration for a test - version = 2 // value used when comparing client and server versions - increment = time.Second // increment to display results for, in seconds - minInterval = 10 * time.Millisecond // minimum interval length for a result to be included - DefaultPort = 20333 -) - -// config is the initial message sent to the server, that contains information on how to -// conduct the test. -type config struct { - Version int `json:"version"` - TestDuration time.Duration `json:"time"` - Direction Direction `json:"direction"` -} - -// configResponse is the response to the testConfig message. If the server has an -// error with the config, the Error variable will hold that error value. -type configResponse struct { - Error string `json:"error,omitempty"` -} - -// This represents the Result of a speedtest within a specific interval -type Result struct { - Bytes int // number of bytes sent/received during the interval - IntervalStart time.Time // start of the interval - IntervalEnd time.Time // end of the interval - Total bool // if true, this result struct represents the entire test, rather than a segment of the test -} - -func (r Result) MBitsPerSecond() float64 { - return r.MegaBits() / r.IntervalEnd.Sub(r.IntervalStart).Seconds() -} - -func (r Result) MegaBytes() float64 { - return float64(r.Bytes) / 1000000.0 -} - -func (r Result) MegaBits() float64 { - return r.MegaBytes() * 8.0 -} - -func (r Result) Interval() time.Duration { - return r.IntervalEnd.Sub(r.IntervalStart) -} - -type Direction int - -const ( - Download Direction = iota - Upload -) - -func (d Direction) String() string { - switch d { - case Upload: - return "upload" - case Download: - return "download" - default: - return "" - } -} - -func (d *Direction) Reverse() { - switch *d { - case Upload: - *d = Download - case Download: - *d = Upload - default: - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package speedtest contains both server and client code for +// running speedtests between tailscale nodes. +package speedtest + +import ( + "time" +) + +const ( + blockSize = 2 * 1024 * 1024 // size of the block of data to send + MinDuration = 5 * time.Second // minimum duration for a test + DefaultDuration = MinDuration // default duration for a test + MaxDuration = 30 * time.Second // maximum duration for a test + version = 2 // value used when comparing client and server versions + increment = time.Second // increment to display results for, in seconds + minInterval = 10 * time.Millisecond // minimum interval length for a result to be included + DefaultPort = 20333 +) + +// config is the initial message sent to the server, that contains information on how to +// conduct the test. +type config struct { + Version int `json:"version"` + TestDuration time.Duration `json:"time"` + Direction Direction `json:"direction"` +} + +// configResponse is the response to the testConfig message. If the server has an +// error with the config, the Error variable will hold that error value. +type configResponse struct { + Error string `json:"error,omitempty"` +} + +// This represents the Result of a speedtest within a specific interval +type Result struct { + Bytes int // number of bytes sent/received during the interval + IntervalStart time.Time // start of the interval + IntervalEnd time.Time // end of the interval + Total bool // if true, this result struct represents the entire test, rather than a segment of the test +} + +func (r Result) MBitsPerSecond() float64 { + return r.MegaBits() / r.IntervalEnd.Sub(r.IntervalStart).Seconds() +} + +func (r Result) MegaBytes() float64 { + return float64(r.Bytes) / 1000000.0 +} + +func (r Result) MegaBits() float64 { + return r.MegaBytes() * 8.0 +} + +func (r Result) Interval() time.Duration { + return r.IntervalEnd.Sub(r.IntervalStart) +} + +type Direction int + +const ( + Download Direction = iota + Upload +) + +func (d Direction) String() string { + switch d { + case Upload: + return "upload" + case Download: + return "download" + default: + return "" + } +} + +func (d *Direction) Reverse() { + switch *d { + case Upload: + *d = Download + case Download: + *d = Upload + default: + } +} diff --git a/net/speedtest/speedtest_client.go b/net/speedtest/speedtest_client.go index cc34c468c..299a12a8d 100644 --- a/net/speedtest/speedtest_client.go +++ b/net/speedtest/speedtest_client.go @@ -1,41 +1,41 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package speedtest - -import ( - "encoding/json" - "errors" - "net" - "time" -) - -// RunClient dials the given address and starts a speedtest. -// It returns any errors that come up in the tests. -// If there are no errors in the test, it returns a slice of results. -func RunClient(direction Direction, duration time.Duration, host string) ([]Result, error) { - conn, err := net.Dial("tcp", host) - if err != nil { - return nil, err - } - - conf := config{TestDuration: duration, Version: version, Direction: direction} - - defer conn.Close() - encoder := json.NewEncoder(conn) - - if err = encoder.Encode(conf); err != nil { - return nil, err - } - - var response configResponse - decoder := json.NewDecoder(conn) - if err = decoder.Decode(&response); err != nil { - return nil, err - } - if response.Error != "" { - return nil, errors.New(response.Error) - } - - return doTest(conn, conf) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package speedtest + +import ( + "encoding/json" + "errors" + "net" + "time" +) + +// RunClient dials the given address and starts a speedtest. +// It returns any errors that come up in the tests. +// If there are no errors in the test, it returns a slice of results. +func RunClient(direction Direction, duration time.Duration, host string) ([]Result, error) { + conn, err := net.Dial("tcp", host) + if err != nil { + return nil, err + } + + conf := config{TestDuration: duration, Version: version, Direction: direction} + + defer conn.Close() + encoder := json.NewEncoder(conn) + + if err = encoder.Encode(conf); err != nil { + return nil, err + } + + var response configResponse + decoder := json.NewDecoder(conn) + if err = decoder.Decode(&response); err != nil { + return nil, err + } + if response.Error != "" { + return nil, errors.New(response.Error) + } + + return doTest(conn, conf) +} diff --git a/net/speedtest/speedtest_server.go b/net/speedtest/speedtest_server.go index d2673464e..9dd78b195 100644 --- a/net/speedtest/speedtest_server.go +++ b/net/speedtest/speedtest_server.go @@ -1,146 +1,146 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package speedtest - -import ( - "crypto/rand" - "encoding/json" - "errors" - "fmt" - "io" - "net" - "time" -) - -// Serve starts up the server on a given host and port pair. It starts to listen for -// connections and handles each one in a goroutine. Because it runs in an infinite loop, -// this function only returns if any of the speedtests return with errors, or if the -// listener is closed. -func Serve(l net.Listener) error { - for { - conn, err := l.Accept() - if errors.Is(err, net.ErrClosed) { - return nil - } - if err != nil { - return err - } - err = handleConnection(conn) - if err != nil { - return err - } - } -} - -// handleConnection handles the initial exchange between the server and the client. -// It reads the testconfig message into a config struct. If any errors occur with -// the testconfig (specifically, if there is a version mismatch), it will return those -// errors to the client with a configResponse. After the exchange, it will start -// the speed test. -func handleConnection(conn net.Conn) error { - defer conn.Close() - var conf config - - decoder := json.NewDecoder(conn) - err := decoder.Decode(&conf) - encoder := json.NewEncoder(conn) - - // Both return and encode errors that occurred before the test started. - if err != nil { - encoder.Encode(configResponse{Error: err.Error()}) - return err - } - - // The server should always be doing the opposite of what the client is doing. - conf.Direction.Reverse() - - if conf.Version != version { - err = fmt.Errorf("version mismatch! Server is version %d, client is version %d", version, conf.Version) - encoder.Encode(configResponse{Error: err.Error()}) - return err - } - - // Start the test - encoder.Encode(configResponse{}) - _, err = doTest(conn, conf) - return err -} - -// TODO include code to detect whether the code is direct vs DERP - -// doTest contains the code to run both the upload and download speedtest. -// the direction value in the config parameter determines which test to run. -func doTest(conn net.Conn, conf config) ([]Result, error) { - bufferData := make([]byte, blockSize) - - intervalBytes := 0 - totalBytes := 0 - - var currentTime time.Time - var results []Result - - if conf.Direction == Download { - conn.SetReadDeadline(time.Now().Add(conf.TestDuration).Add(5 * time.Second)) - } else { - _, err := rand.Read(bufferData) - if err != nil { - return nil, err - } - - } - - startTime := time.Now() - lastCalculated := startTime - -SpeedTestLoop: - for { - var n int - var err error - - if conf.Direction == Download { - n, err = io.ReadFull(conn, bufferData) - switch err { - case io.EOF, io.ErrUnexpectedEOF: - break SpeedTestLoop - case nil: - // successful read - default: - return nil, fmt.Errorf("unexpected error has occurred: %w", err) - } - } else { - n, err = conn.Write(bufferData) - if err != nil { - // If the write failed, there is most likely something wrong with the connection. - return nil, fmt.Errorf("upload failed: %w", err) - } - } - intervalBytes += n - - currentTime = time.Now() - // checks if the current time is more or equal to the lastCalculated time plus the increment - if currentTime.Sub(lastCalculated) >= increment { - results = append(results, Result{Bytes: intervalBytes, IntervalStart: lastCalculated, IntervalEnd: currentTime, Total: false}) - lastCalculated = currentTime - totalBytes += intervalBytes - intervalBytes = 0 - } - - if conf.Direction == Upload && currentTime.Sub(startTime) > conf.TestDuration { - break SpeedTestLoop - } - } - - // get last segment - if currentTime.Sub(lastCalculated) > minInterval { - results = append(results, Result{Bytes: intervalBytes, IntervalStart: lastCalculated, IntervalEnd: currentTime, Total: false}) - } - - // get total - totalBytes += intervalBytes - if currentTime.Sub(startTime) > minInterval { - results = append(results, Result{Bytes: totalBytes, IntervalStart: startTime, IntervalEnd: currentTime, Total: true}) - } - - return results, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package speedtest + +import ( + "crypto/rand" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "time" +) + +// Serve starts up the server on a given host and port pair. It starts to listen for +// connections and handles each one in a goroutine. Because it runs in an infinite loop, +// this function only returns if any of the speedtests return with errors, or if the +// listener is closed. +func Serve(l net.Listener) error { + for { + conn, err := l.Accept() + if errors.Is(err, net.ErrClosed) { + return nil + } + if err != nil { + return err + } + err = handleConnection(conn) + if err != nil { + return err + } + } +} + +// handleConnection handles the initial exchange between the server and the client. +// It reads the testconfig message into a config struct. If any errors occur with +// the testconfig (specifically, if there is a version mismatch), it will return those +// errors to the client with a configResponse. After the exchange, it will start +// the speed test. +func handleConnection(conn net.Conn) error { + defer conn.Close() + var conf config + + decoder := json.NewDecoder(conn) + err := decoder.Decode(&conf) + encoder := json.NewEncoder(conn) + + // Both return and encode errors that occurred before the test started. + if err != nil { + encoder.Encode(configResponse{Error: err.Error()}) + return err + } + + // The server should always be doing the opposite of what the client is doing. + conf.Direction.Reverse() + + if conf.Version != version { + err = fmt.Errorf("version mismatch! Server is version %d, client is version %d", version, conf.Version) + encoder.Encode(configResponse{Error: err.Error()}) + return err + } + + // Start the test + encoder.Encode(configResponse{}) + _, err = doTest(conn, conf) + return err +} + +// TODO include code to detect whether the code is direct vs DERP + +// doTest contains the code to run both the upload and download speedtest. +// the direction value in the config parameter determines which test to run. +func doTest(conn net.Conn, conf config) ([]Result, error) { + bufferData := make([]byte, blockSize) + + intervalBytes := 0 + totalBytes := 0 + + var currentTime time.Time + var results []Result + + if conf.Direction == Download { + conn.SetReadDeadline(time.Now().Add(conf.TestDuration).Add(5 * time.Second)) + } else { + _, err := rand.Read(bufferData) + if err != nil { + return nil, err + } + + } + + startTime := time.Now() + lastCalculated := startTime + +SpeedTestLoop: + for { + var n int + var err error + + if conf.Direction == Download { + n, err = io.ReadFull(conn, bufferData) + switch err { + case io.EOF, io.ErrUnexpectedEOF: + break SpeedTestLoop + case nil: + // successful read + default: + return nil, fmt.Errorf("unexpected error has occurred: %w", err) + } + } else { + n, err = conn.Write(bufferData) + if err != nil { + // If the write failed, there is most likely something wrong with the connection. + return nil, fmt.Errorf("upload failed: %w", err) + } + } + intervalBytes += n + + currentTime = time.Now() + // checks if the current time is more or equal to the lastCalculated time plus the increment + if currentTime.Sub(lastCalculated) >= increment { + results = append(results, Result{Bytes: intervalBytes, IntervalStart: lastCalculated, IntervalEnd: currentTime, Total: false}) + lastCalculated = currentTime + totalBytes += intervalBytes + intervalBytes = 0 + } + + if conf.Direction == Upload && currentTime.Sub(startTime) > conf.TestDuration { + break SpeedTestLoop + } + } + + // get last segment + if currentTime.Sub(lastCalculated) > minInterval { + results = append(results, Result{Bytes: intervalBytes, IntervalStart: lastCalculated, IntervalEnd: currentTime, Total: false}) + } + + // get total + totalBytes += intervalBytes + if currentTime.Sub(startTime) > minInterval { + results = append(results, Result{Bytes: totalBytes, IntervalStart: startTime, IntervalEnd: currentTime, Total: true}) + } + + return results, nil +} diff --git a/net/speedtest/speedtest_test.go b/net/speedtest/speedtest_test.go index a413e9efa..55dcbeea1 100644 --- a/net/speedtest/speedtest_test.go +++ b/net/speedtest/speedtest_test.go @@ -1,83 +1,83 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package speedtest - -import ( - "net" - "testing" - "time" -) - -func TestDownload(t *testing.T) { - // start a listener and find the port where the server will be listening. - l, err := net.Listen("tcp", ":0") - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { l.Close() }) - - serverIP := l.Addr().String() - t.Log("server IP found:", serverIP) - - type state struct { - err error - } - displayResult := func(t *testing.T, r Result, start time.Time) { - t.Helper() - t.Logf("{ Megabytes: %.2f, Start: %.1f, End: %.1f, Total: %t }", r.MegaBytes(), r.IntervalStart.Sub(start).Seconds(), r.IntervalEnd.Sub(start).Seconds(), r.Total) - } - stateChan := make(chan state, 1) - - go func() { - err := Serve(l) - stateChan <- state{err: err} - }() - - // ensure that the test returns an appropriate number of Result structs - expectedLen := int(DefaultDuration.Seconds()) + 1 - - t.Run("download test", func(t *testing.T) { - // conduct a download test - results, err := RunClient(Download, DefaultDuration, serverIP) - - if err != nil { - t.Fatal("download test failed:", err) - } - - if len(results) < expectedLen { - t.Fatalf("download results: expected length: %d, actual length: %d", expectedLen, len(results)) - } - - start := results[0].IntervalStart - for _, result := range results { - displayResult(t, result, start) - } - }) - - t.Run("upload test", func(t *testing.T) { - // conduct an upload test - results, err := RunClient(Upload, DefaultDuration, serverIP) - - if err != nil { - t.Fatal("upload test failed:", err) - } - - if len(results) < expectedLen { - t.Fatalf("upload results: expected length: %d, actual length: %d", expectedLen, len(results)) - } - - start := results[0].IntervalStart - for _, result := range results { - displayResult(t, result, start) - } - }) - - // causes the server goroutine to finish - l.Close() - - testState := <-stateChan - if testState.err != nil { - t.Error("server error:", err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package speedtest + +import ( + "net" + "testing" + "time" +) + +func TestDownload(t *testing.T) { + // start a listener and find the port where the server will be listening. + l, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { l.Close() }) + + serverIP := l.Addr().String() + t.Log("server IP found:", serverIP) + + type state struct { + err error + } + displayResult := func(t *testing.T, r Result, start time.Time) { + t.Helper() + t.Logf("{ Megabytes: %.2f, Start: %.1f, End: %.1f, Total: %t }", r.MegaBytes(), r.IntervalStart.Sub(start).Seconds(), r.IntervalEnd.Sub(start).Seconds(), r.Total) + } + stateChan := make(chan state, 1) + + go func() { + err := Serve(l) + stateChan <- state{err: err} + }() + + // ensure that the test returns an appropriate number of Result structs + expectedLen := int(DefaultDuration.Seconds()) + 1 + + t.Run("download test", func(t *testing.T) { + // conduct a download test + results, err := RunClient(Download, DefaultDuration, serverIP) + + if err != nil { + t.Fatal("download test failed:", err) + } + + if len(results) < expectedLen { + t.Fatalf("download results: expected length: %d, actual length: %d", expectedLen, len(results)) + } + + start := results[0].IntervalStart + for _, result := range results { + displayResult(t, result, start) + } + }) + + t.Run("upload test", func(t *testing.T) { + // conduct an upload test + results, err := RunClient(Upload, DefaultDuration, serverIP) + + if err != nil { + t.Fatal("upload test failed:", err) + } + + if len(results) < expectedLen { + t.Fatalf("upload results: expected length: %d, actual length: %d", expectedLen, len(results)) + } + + start := results[0].IntervalStart + for _, result := range results { + displayResult(t, result, start) + } + }) + + // causes the server goroutine to finish + l.Close() + + testState := <-stateChan + if testState.err != nil { + t.Error("server error:", err) + } +} diff --git a/net/stun/stun.go b/net/stun/stun.go index 81cf9b608..eeac23cbb 100644 --- a/net/stun/stun.go +++ b/net/stun/stun.go @@ -1,312 +1,312 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package STUN generates STUN request packets and parses response packets. -package stun - -import ( - "bytes" - crand "crypto/rand" - "encoding/binary" - "errors" - "hash/crc32" - "net" - "net/netip" -) - -const ( - attrNumSoftware = 0x8022 - attrNumFingerprint = 0x8028 - attrMappedAddress = 0x0001 - attrXorMappedAddress = 0x0020 - // This alternative attribute type is not - // mentioned in the RFC, but the shift into - // the "comprehension-optional" range seems - // like an easy mistake for a server to make. - // And servers appear to send it. - attrXorMappedAddressAlt = 0x8020 - - software = "tailnode" // notably: 8 bytes long, so no padding - bindingRequest = "\x00\x01" - magicCookie = "\x21\x12\xa4\x42" - lenFingerprint = 8 // 2+byte header + 2-byte length + 4-byte crc32 - headerLen = 20 -) - -// TxID is a transaction ID. -type TxID [12]byte - -// NewTxID returns a new random TxID. -func NewTxID() TxID { - var tx TxID - if _, err := crand.Read(tx[:]); err != nil { - panic(err) - } - return tx -} - -// Request generates a binding request STUN packet. -// The transaction ID, tID, should be a random sequence of bytes. -func Request(tID TxID) []byte { - // STUN header, RFC5389 Section 6. - const lenAttrSoftware = 4 + len(software) - b := make([]byte, 0, headerLen+lenAttrSoftware+lenFingerprint) - b = append(b, bindingRequest...) - b = appendU16(b, uint16(lenAttrSoftware+lenFingerprint)) // number of bytes following header - b = append(b, magicCookie...) - b = append(b, tID[:]...) - - // Attribute SOFTWARE, RFC5389 Section 15.5. - b = appendU16(b, attrNumSoftware) - b = appendU16(b, uint16(len(software))) - b = append(b, software...) - - // Attribute FINGERPRINT, RFC5389 Section 15.5. - fp := fingerPrint(b) - b = appendU16(b, attrNumFingerprint) - b = appendU16(b, 4) - b = appendU32(b, fp) - - return b -} - -func fingerPrint(b []byte) uint32 { return crc32.ChecksumIEEE(b) ^ 0x5354554e } - -func appendU16(b []byte, v uint16) []byte { - return append(b, byte(v>>8), byte(v)) -} - -func appendU32(b []byte, v uint32) []byte { - return append(b, byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) -} - -// ParseBindingRequest parses a STUN binding request. -// -// It returns an error unless it advertises that it came from -// Tailscale. -func ParseBindingRequest(b []byte) (TxID, error) { - if !Is(b) { - return TxID{}, ErrNotSTUN - } - if string(b[:len(bindingRequest)]) != bindingRequest { - return TxID{}, ErrNotBindingRequest - } - var txID TxID - copy(txID[:], b[8:8+len(txID)]) - var softwareOK bool - var lastAttr uint16 - var gotFP uint32 - if err := foreachAttr(b[headerLen:], func(attrType uint16, a []byte) error { - lastAttr = attrType - if attrType == attrNumSoftware && string(a) == software { - softwareOK = true - } - if attrType == attrNumFingerprint && len(a) == 4 { - gotFP = binary.BigEndian.Uint32(a) - } - return nil - }); err != nil { - return TxID{}, err - } - if !softwareOK { - return TxID{}, ErrWrongSoftware - } - if lastAttr != attrNumFingerprint { - return TxID{}, ErrNoFingerprint - } - wantFP := fingerPrint(b[:len(b)-lenFingerprint]) - if gotFP != wantFP { - return TxID{}, ErrWrongFingerprint - } - return txID, nil -} - -var ( - ErrNotSTUN = errors.New("response is not a STUN packet") - ErrNotSuccessResponse = errors.New("STUN packet is not a response") - ErrMalformedAttrs = errors.New("STUN response has malformed attributes") - ErrNotBindingRequest = errors.New("STUN request not a binding request") - ErrWrongSoftware = errors.New("STUN request came from non-Tailscale software") - ErrNoFingerprint = errors.New("STUN request didn't end in fingerprint") - ErrWrongFingerprint = errors.New("STUN request had bogus fingerprint") -) - -func foreachAttr(b []byte, fn func(attrType uint16, a []byte) error) error { - for len(b) > 0 { - if len(b) < 4 { - return ErrMalformedAttrs - } - attrType := binary.BigEndian.Uint16(b[:2]) - attrLen := int(binary.BigEndian.Uint16(b[2:4])) - attrLenWithPad := (attrLen + 3) &^ 3 - b = b[4:] - if attrLenWithPad > len(b) { - return ErrMalformedAttrs - } - if err := fn(attrType, b[:attrLen]); err != nil { - return err - } - b = b[attrLenWithPad:] - } - return nil -} - -// Response generates a binding response. -func Response(txID TxID, addrPort netip.AddrPort) []byte { - addr := addrPort.Addr() - - var fam byte - if addr.Is4() { - fam = 1 - } else if addr.Is6() { - fam = 2 - } else { - return nil - } - attrsLen := 8 + addr.BitLen()/8 - b := make([]byte, 0, headerLen+attrsLen) - - // Header - b = append(b, 0x01, 0x01) // success - b = appendU16(b, uint16(attrsLen)) - b = append(b, magicCookie...) - b = append(b, txID[:]...) - - // Attributes (well, one) - b = appendU16(b, attrXorMappedAddress) - b = appendU16(b, uint16(4+addr.BitLen()/8)) - b = append(b, - 0, // unused byte - fam) - b = appendU16(b, addrPort.Port()^0x2112) // first half of magicCookie - ipa := addr.As16() - for i, o := range ipa[16-addr.BitLen()/8:] { - if i < 4 { - b = append(b, o^magicCookie[i]) - } else { - b = append(b, o^txID[i-len(magicCookie)]) - } - } - return b -} - -// ParseResponse parses a successful binding response STUN packet. -// The IP address is extracted from the XOR-MAPPED-ADDRESS attribute. -func ParseResponse(b []byte) (tID TxID, addr netip.AddrPort, err error) { - if !Is(b) { - return tID, netip.AddrPort{}, ErrNotSTUN - } - copy(tID[:], b[8:8+len(tID)]) - if b[0] != 0x01 || b[1] != 0x01 { - return tID, netip.AddrPort{}, ErrNotSuccessResponse - } - attrsLen := int(binary.BigEndian.Uint16(b[2:4])) - b = b[headerLen:] // remove STUN header - if attrsLen > len(b) { - return tID, netip.AddrPort{}, ErrMalformedAttrs - } else if len(b) > attrsLen { - b = b[:attrsLen] // trim trailing packet bytes - } - - var fallbackAddr netip.AddrPort - - // Read through the attributes. - // The the addr+port reported by XOR-MAPPED-ADDRESS - // as the canonical value. If the attribute is not - // present but the STUN server responds with - // MAPPED-ADDRESS we fall back to it. - if err := foreachAttr(b, func(attrType uint16, attr []byte) error { - switch attrType { - case attrXorMappedAddress, attrXorMappedAddressAlt: - ipSlice, port, err := xorMappedAddress(tID, attr) - if err != nil { - return err - } - if ip, ok := netip.AddrFromSlice(ipSlice); ok { - addr = netip.AddrPortFrom(ip.Unmap(), port) - } - case attrMappedAddress: - ipSlice, port, err := mappedAddress(attr) - if err != nil { - return ErrMalformedAttrs - } - if ip, ok := netip.AddrFromSlice(ipSlice); ok { - fallbackAddr = netip.AddrPortFrom(ip.Unmap(), port) - } - } - return nil - - }); err != nil { - return TxID{}, netip.AddrPort{}, err - } - - if addr.IsValid() { - return tID, addr, nil - } - if fallbackAddr.IsValid() { - return tID, fallbackAddr, nil - } - return tID, netip.AddrPort{}, ErrMalformedAttrs -} - -func xorMappedAddress(tID TxID, b []byte) (addr []byte, port uint16, err error) { - // XOR-MAPPED-ADDRESS attribute, RFC5389 Section 15.2 - if len(b) < 4 { - return nil, 0, ErrMalformedAttrs - } - xorPort := binary.BigEndian.Uint16(b[2:4]) - addrField := b[4:] - port = xorPort ^ 0x2112 // first half of magicCookie - - addrLen := familyAddrLen(b[1]) - if addrLen == 0 { - return nil, 0, ErrMalformedAttrs - } - if len(addrField) < addrLen { - return nil, 0, ErrMalformedAttrs - } - xorAddr := addrField[:addrLen] - addr = make([]byte, addrLen) - for i := range xorAddr { - if i < len(magicCookie) { - addr[i] = xorAddr[i] ^ magicCookie[i] - } else { - addr[i] = xorAddr[i] ^ tID[i-len(magicCookie)] - } - } - return addr, port, nil -} - -func familyAddrLen(fam byte) int { - switch fam { - case 0x01: // IPv4 - return net.IPv4len - case 0x02: // IPv6 - return net.IPv6len - default: - return 0 - } -} - -func mappedAddress(b []byte) (addr []byte, port uint16, err error) { - if len(b) < 4 { - return nil, 0, ErrMalformedAttrs - } - port = uint16(b[2])<<8 | uint16(b[3]) - addrField := b[4:] - addrLen := familyAddrLen(b[1]) - if addrLen == 0 { - return nil, 0, ErrMalformedAttrs - } - if len(addrField) < addrLen { - return nil, 0, ErrMalformedAttrs - } - return bytes.Clone(addrField[:addrLen]), port, nil -} - -// Is reports whether b is a STUN message. -func Is(b []byte) bool { - return len(b) >= headerLen && - b[0]&0b11000000 == 0 && // top two bits must be zero - string(b[4:8]) == magicCookie -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package STUN generates STUN request packets and parses response packets. +package stun + +import ( + "bytes" + crand "crypto/rand" + "encoding/binary" + "errors" + "hash/crc32" + "net" + "net/netip" +) + +const ( + attrNumSoftware = 0x8022 + attrNumFingerprint = 0x8028 + attrMappedAddress = 0x0001 + attrXorMappedAddress = 0x0020 + // This alternative attribute type is not + // mentioned in the RFC, but the shift into + // the "comprehension-optional" range seems + // like an easy mistake for a server to make. + // And servers appear to send it. + attrXorMappedAddressAlt = 0x8020 + + software = "tailnode" // notably: 8 bytes long, so no padding + bindingRequest = "\x00\x01" + magicCookie = "\x21\x12\xa4\x42" + lenFingerprint = 8 // 2+byte header + 2-byte length + 4-byte crc32 + headerLen = 20 +) + +// TxID is a transaction ID. +type TxID [12]byte + +// NewTxID returns a new random TxID. +func NewTxID() TxID { + var tx TxID + if _, err := crand.Read(tx[:]); err != nil { + panic(err) + } + return tx +} + +// Request generates a binding request STUN packet. +// The transaction ID, tID, should be a random sequence of bytes. +func Request(tID TxID) []byte { + // STUN header, RFC5389 Section 6. + const lenAttrSoftware = 4 + len(software) + b := make([]byte, 0, headerLen+lenAttrSoftware+lenFingerprint) + b = append(b, bindingRequest...) + b = appendU16(b, uint16(lenAttrSoftware+lenFingerprint)) // number of bytes following header + b = append(b, magicCookie...) + b = append(b, tID[:]...) + + // Attribute SOFTWARE, RFC5389 Section 15.5. + b = appendU16(b, attrNumSoftware) + b = appendU16(b, uint16(len(software))) + b = append(b, software...) + + // Attribute FINGERPRINT, RFC5389 Section 15.5. + fp := fingerPrint(b) + b = appendU16(b, attrNumFingerprint) + b = appendU16(b, 4) + b = appendU32(b, fp) + + return b +} + +func fingerPrint(b []byte) uint32 { return crc32.ChecksumIEEE(b) ^ 0x5354554e } + +func appendU16(b []byte, v uint16) []byte { + return append(b, byte(v>>8), byte(v)) +} + +func appendU32(b []byte, v uint32) []byte { + return append(b, byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) +} + +// ParseBindingRequest parses a STUN binding request. +// +// It returns an error unless it advertises that it came from +// Tailscale. +func ParseBindingRequest(b []byte) (TxID, error) { + if !Is(b) { + return TxID{}, ErrNotSTUN + } + if string(b[:len(bindingRequest)]) != bindingRequest { + return TxID{}, ErrNotBindingRequest + } + var txID TxID + copy(txID[:], b[8:8+len(txID)]) + var softwareOK bool + var lastAttr uint16 + var gotFP uint32 + if err := foreachAttr(b[headerLen:], func(attrType uint16, a []byte) error { + lastAttr = attrType + if attrType == attrNumSoftware && string(a) == software { + softwareOK = true + } + if attrType == attrNumFingerprint && len(a) == 4 { + gotFP = binary.BigEndian.Uint32(a) + } + return nil + }); err != nil { + return TxID{}, err + } + if !softwareOK { + return TxID{}, ErrWrongSoftware + } + if lastAttr != attrNumFingerprint { + return TxID{}, ErrNoFingerprint + } + wantFP := fingerPrint(b[:len(b)-lenFingerprint]) + if gotFP != wantFP { + return TxID{}, ErrWrongFingerprint + } + return txID, nil +} + +var ( + ErrNotSTUN = errors.New("response is not a STUN packet") + ErrNotSuccessResponse = errors.New("STUN packet is not a response") + ErrMalformedAttrs = errors.New("STUN response has malformed attributes") + ErrNotBindingRequest = errors.New("STUN request not a binding request") + ErrWrongSoftware = errors.New("STUN request came from non-Tailscale software") + ErrNoFingerprint = errors.New("STUN request didn't end in fingerprint") + ErrWrongFingerprint = errors.New("STUN request had bogus fingerprint") +) + +func foreachAttr(b []byte, fn func(attrType uint16, a []byte) error) error { + for len(b) > 0 { + if len(b) < 4 { + return ErrMalformedAttrs + } + attrType := binary.BigEndian.Uint16(b[:2]) + attrLen := int(binary.BigEndian.Uint16(b[2:4])) + attrLenWithPad := (attrLen + 3) &^ 3 + b = b[4:] + if attrLenWithPad > len(b) { + return ErrMalformedAttrs + } + if err := fn(attrType, b[:attrLen]); err != nil { + return err + } + b = b[attrLenWithPad:] + } + return nil +} + +// Response generates a binding response. +func Response(txID TxID, addrPort netip.AddrPort) []byte { + addr := addrPort.Addr() + + var fam byte + if addr.Is4() { + fam = 1 + } else if addr.Is6() { + fam = 2 + } else { + return nil + } + attrsLen := 8 + addr.BitLen()/8 + b := make([]byte, 0, headerLen+attrsLen) + + // Header + b = append(b, 0x01, 0x01) // success + b = appendU16(b, uint16(attrsLen)) + b = append(b, magicCookie...) + b = append(b, txID[:]...) + + // Attributes (well, one) + b = appendU16(b, attrXorMappedAddress) + b = appendU16(b, uint16(4+addr.BitLen()/8)) + b = append(b, + 0, // unused byte + fam) + b = appendU16(b, addrPort.Port()^0x2112) // first half of magicCookie + ipa := addr.As16() + for i, o := range ipa[16-addr.BitLen()/8:] { + if i < 4 { + b = append(b, o^magicCookie[i]) + } else { + b = append(b, o^txID[i-len(magicCookie)]) + } + } + return b +} + +// ParseResponse parses a successful binding response STUN packet. +// The IP address is extracted from the XOR-MAPPED-ADDRESS attribute. +func ParseResponse(b []byte) (tID TxID, addr netip.AddrPort, err error) { + if !Is(b) { + return tID, netip.AddrPort{}, ErrNotSTUN + } + copy(tID[:], b[8:8+len(tID)]) + if b[0] != 0x01 || b[1] != 0x01 { + return tID, netip.AddrPort{}, ErrNotSuccessResponse + } + attrsLen := int(binary.BigEndian.Uint16(b[2:4])) + b = b[headerLen:] // remove STUN header + if attrsLen > len(b) { + return tID, netip.AddrPort{}, ErrMalformedAttrs + } else if len(b) > attrsLen { + b = b[:attrsLen] // trim trailing packet bytes + } + + var fallbackAddr netip.AddrPort + + // Read through the attributes. + // The the addr+port reported by XOR-MAPPED-ADDRESS + // as the canonical value. If the attribute is not + // present but the STUN server responds with + // MAPPED-ADDRESS we fall back to it. + if err := foreachAttr(b, func(attrType uint16, attr []byte) error { + switch attrType { + case attrXorMappedAddress, attrXorMappedAddressAlt: + ipSlice, port, err := xorMappedAddress(tID, attr) + if err != nil { + return err + } + if ip, ok := netip.AddrFromSlice(ipSlice); ok { + addr = netip.AddrPortFrom(ip.Unmap(), port) + } + case attrMappedAddress: + ipSlice, port, err := mappedAddress(attr) + if err != nil { + return ErrMalformedAttrs + } + if ip, ok := netip.AddrFromSlice(ipSlice); ok { + fallbackAddr = netip.AddrPortFrom(ip.Unmap(), port) + } + } + return nil + + }); err != nil { + return TxID{}, netip.AddrPort{}, err + } + + if addr.IsValid() { + return tID, addr, nil + } + if fallbackAddr.IsValid() { + return tID, fallbackAddr, nil + } + return tID, netip.AddrPort{}, ErrMalformedAttrs +} + +func xorMappedAddress(tID TxID, b []byte) (addr []byte, port uint16, err error) { + // XOR-MAPPED-ADDRESS attribute, RFC5389 Section 15.2 + if len(b) < 4 { + return nil, 0, ErrMalformedAttrs + } + xorPort := binary.BigEndian.Uint16(b[2:4]) + addrField := b[4:] + port = xorPort ^ 0x2112 // first half of magicCookie + + addrLen := familyAddrLen(b[1]) + if addrLen == 0 { + return nil, 0, ErrMalformedAttrs + } + if len(addrField) < addrLen { + return nil, 0, ErrMalformedAttrs + } + xorAddr := addrField[:addrLen] + addr = make([]byte, addrLen) + for i := range xorAddr { + if i < len(magicCookie) { + addr[i] = xorAddr[i] ^ magicCookie[i] + } else { + addr[i] = xorAddr[i] ^ tID[i-len(magicCookie)] + } + } + return addr, port, nil +} + +func familyAddrLen(fam byte) int { + switch fam { + case 0x01: // IPv4 + return net.IPv4len + case 0x02: // IPv6 + return net.IPv6len + default: + return 0 + } +} + +func mappedAddress(b []byte) (addr []byte, port uint16, err error) { + if len(b) < 4 { + return nil, 0, ErrMalformedAttrs + } + port = uint16(b[2])<<8 | uint16(b[3]) + addrField := b[4:] + addrLen := familyAddrLen(b[1]) + if addrLen == 0 { + return nil, 0, ErrMalformedAttrs + } + if len(addrField) < addrLen { + return nil, 0, ErrMalformedAttrs + } + return bytes.Clone(addrField[:addrLen]), port, nil +} + +// Is reports whether b is a STUN message. +func Is(b []byte) bool { + return len(b) >= headerLen && + b[0]&0b11000000 == 0 && // top two bits must be zero + string(b[4:8]) == magicCookie +} diff --git a/net/stun/stun_fuzzer.go b/net/stun/stun_fuzzer.go index 9ddb41895..6f0c9e3b0 100644 --- a/net/stun/stun_fuzzer.go +++ b/net/stun/stun_fuzzer.go @@ -1,12 +1,12 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause -//go:build gofuzz - -package stun - -func FuzzStunParser(data []byte) int { - _, _, _ = ParseResponse(data) - - _, _ = ParseBindingRequest(data) - return 1 -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause +//go:build gofuzz + +package stun + +func FuzzStunParser(data []byte) int { + _, _, _ = ParseResponse(data) + + _, _ = ParseBindingRequest(data) + return 1 +} diff --git a/net/tcpinfo/tcpinfo.go b/net/tcpinfo/tcpinfo.go index adc40ca37..a757add9f 100644 --- a/net/tcpinfo/tcpinfo.go +++ b/net/tcpinfo/tcpinfo.go @@ -1,51 +1,51 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package tcpinfo provides platform-agnostic accessors to information about a -// TCP connection (e.g. RTT, MSS, etc.). -package tcpinfo - -import ( - "errors" - "net" - "time" -) - -var ( - ErrNotTCP = errors.New("tcpinfo: not a TCP conn") - ErrUnimplemented = errors.New("tcpinfo: unimplemented") -) - -// RTT returns the RTT for the given net.Conn. -// -// If the net.Conn is not a *net.TCPConn and cannot be unwrapped into one, then -// ErrNotTCP will be returned. If retrieving the RTT is not supported on the -// current platform, ErrUnimplemented will be returned. -func RTT(conn net.Conn) (time.Duration, error) { - tcpConn, err := unwrap(conn) - if err != nil { - return 0, err - } - - return rttImpl(tcpConn) -} - -// netConner is implemented by crypto/tls.Conn to unwrap into an underlying -// net.Conn. -type netConner interface { - NetConn() net.Conn -} - -// unwrap attempts to unwrap a net.Conn into an underlying *net.TCPConn -func unwrap(nc net.Conn) (*net.TCPConn, error) { - for { - switch v := nc.(type) { - case *net.TCPConn: - return v, nil - case netConner: - nc = v.NetConn() - default: - return nil, ErrNotTCP - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package tcpinfo provides platform-agnostic accessors to information about a +// TCP connection (e.g. RTT, MSS, etc.). +package tcpinfo + +import ( + "errors" + "net" + "time" +) + +var ( + ErrNotTCP = errors.New("tcpinfo: not a TCP conn") + ErrUnimplemented = errors.New("tcpinfo: unimplemented") +) + +// RTT returns the RTT for the given net.Conn. +// +// If the net.Conn is not a *net.TCPConn and cannot be unwrapped into one, then +// ErrNotTCP will be returned. If retrieving the RTT is not supported on the +// current platform, ErrUnimplemented will be returned. +func RTT(conn net.Conn) (time.Duration, error) { + tcpConn, err := unwrap(conn) + if err != nil { + return 0, err + } + + return rttImpl(tcpConn) +} + +// netConner is implemented by crypto/tls.Conn to unwrap into an underlying +// net.Conn. +type netConner interface { + NetConn() net.Conn +} + +// unwrap attempts to unwrap a net.Conn into an underlying *net.TCPConn +func unwrap(nc net.Conn) (*net.TCPConn, error) { + for { + switch v := nc.(type) { + case *net.TCPConn: + return v, nil + case netConner: + nc = v.NetConn() + default: + return nil, ErrNotTCP + } + } +} diff --git a/net/tcpinfo/tcpinfo_darwin.go b/net/tcpinfo/tcpinfo_darwin.go index bc4ac08b3..53fa22fbf 100644 --- a/net/tcpinfo/tcpinfo_darwin.go +++ b/net/tcpinfo/tcpinfo_darwin.go @@ -1,33 +1,33 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tcpinfo - -import ( - "net" - "time" - - "golang.org/x/sys/unix" -) - -func rttImpl(conn *net.TCPConn) (time.Duration, error) { - rawConn, err := conn.SyscallConn() - if err != nil { - return 0, err - } - - var ( - tcpInfo *unix.TCPConnectionInfo - sysErr error - ) - err = rawConn.Control(func(fd uintptr) { - tcpInfo, sysErr = unix.GetsockoptTCPConnectionInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_CONNECTION_INFO) - }) - if err != nil { - return 0, err - } else if sysErr != nil { - return 0, sysErr - } - - return time.Duration(tcpInfo.Rttcur) * time.Millisecond, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tcpinfo + +import ( + "net" + "time" + + "golang.org/x/sys/unix" +) + +func rttImpl(conn *net.TCPConn) (time.Duration, error) { + rawConn, err := conn.SyscallConn() + if err != nil { + return 0, err + } + + var ( + tcpInfo *unix.TCPConnectionInfo + sysErr error + ) + err = rawConn.Control(func(fd uintptr) { + tcpInfo, sysErr = unix.GetsockoptTCPConnectionInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_CONNECTION_INFO) + }) + if err != nil { + return 0, err + } else if sysErr != nil { + return 0, sysErr + } + + return time.Duration(tcpInfo.Rttcur) * time.Millisecond, nil +} diff --git a/net/tcpinfo/tcpinfo_linux.go b/net/tcpinfo/tcpinfo_linux.go index 5d86055bb..885d462c9 100644 --- a/net/tcpinfo/tcpinfo_linux.go +++ b/net/tcpinfo/tcpinfo_linux.go @@ -1,33 +1,33 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tcpinfo - -import ( - "net" - "time" - - "golang.org/x/sys/unix" -) - -func rttImpl(conn *net.TCPConn) (time.Duration, error) { - rawConn, err := conn.SyscallConn() - if err != nil { - return 0, err - } - - var ( - tcpInfo *unix.TCPInfo - sysErr error - ) - err = rawConn.Control(func(fd uintptr) { - tcpInfo, sysErr = unix.GetsockoptTCPInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_INFO) - }) - if err != nil { - return 0, err - } else if sysErr != nil { - return 0, sysErr - } - - return time.Duration(tcpInfo.Rtt) * time.Microsecond, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tcpinfo + +import ( + "net" + "time" + + "golang.org/x/sys/unix" +) + +func rttImpl(conn *net.TCPConn) (time.Duration, error) { + rawConn, err := conn.SyscallConn() + if err != nil { + return 0, err + } + + var ( + tcpInfo *unix.TCPInfo + sysErr error + ) + err = rawConn.Control(func(fd uintptr) { + tcpInfo, sysErr = unix.GetsockoptTCPInfo(int(fd), unix.IPPROTO_TCP, unix.TCP_INFO) + }) + if err != nil { + return 0, err + } else if sysErr != nil { + return 0, sysErr + } + + return time.Duration(tcpInfo.Rtt) * time.Microsecond, nil +} diff --git a/net/tcpinfo/tcpinfo_other.go b/net/tcpinfo/tcpinfo_other.go index f219cda1b..be45523ae 100644 --- a/net/tcpinfo/tcpinfo_other.go +++ b/net/tcpinfo/tcpinfo_other.go @@ -1,15 +1,15 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux && !darwin - -package tcpinfo - -import ( - "net" - "time" -) - -func rttImpl(conn *net.TCPConn) (time.Duration, error) { - return 0, ErrUnimplemented -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux && !darwin + +package tcpinfo + +import ( + "net" + "time" +) + +func rttImpl(conn *net.TCPConn) (time.Duration, error) { + return 0, ErrUnimplemented +} diff --git a/net/tlsdial/deps_test.go b/net/tlsdial/deps_test.go index 750cb300a..7a93899c2 100644 --- a/net/tlsdial/deps_test.go +++ b/net/tlsdial/deps_test.go @@ -1,8 +1,8 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build for_go_mod_tidy_only - -package tlsdial - -import _ "filippo.io/mkcert" +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build for_go_mod_tidy_only + +package tlsdial + +import _ "filippo.io/mkcert" diff --git a/net/tsdial/dnsmap_test.go b/net/tsdial/dnsmap_test.go index f846b853e..43461a135 100644 --- a/net/tsdial/dnsmap_test.go +++ b/net/tsdial/dnsmap_test.go @@ -1,125 +1,125 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tsdial - -import ( - "net/netip" - "reflect" - "testing" - - "tailscale.com/tailcfg" - "tailscale.com/types/netmap" -) - -func nodeViews(v []*tailcfg.Node) []tailcfg.NodeView { - nv := make([]tailcfg.NodeView, len(v)) - for i, n := range v { - nv[i] = n.View() - } - return nv -} - -func TestDNSMapFromNetworkMap(t *testing.T) { - pfx := netip.MustParsePrefix - ip := netip.MustParseAddr - tests := []struct { - name string - nm *netmap.NetworkMap - want dnsMap - }{ - { - name: "self", - nm: &netmap.NetworkMap{ - Name: "foo.tailnet", - SelfNode: (&tailcfg.Node{ - Addresses: []netip.Prefix{ - pfx("100.102.103.104/32"), - pfx("100::123/128"), - }, - }).View(), - }, - want: dnsMap{ - "foo": ip("100.102.103.104"), - "foo.tailnet": ip("100.102.103.104"), - }, - }, - { - name: "self_and_peers", - nm: &netmap.NetworkMap{ - Name: "foo.tailnet", - SelfNode: (&tailcfg.Node{ - Addresses: []netip.Prefix{ - pfx("100.102.103.104/32"), - pfx("100::123/128"), - }, - }).View(), - Peers: []tailcfg.NodeView{ - (&tailcfg.Node{ - Name: "a.tailnet", - Addresses: []netip.Prefix{ - pfx("100.0.0.201/32"), - pfx("100::201/128"), - }, - }).View(), - (&tailcfg.Node{ - Name: "b.tailnet", - Addresses: []netip.Prefix{ - pfx("100::202/128"), - }, - }).View(), - }, - }, - want: dnsMap{ - "foo": ip("100.102.103.104"), - "foo.tailnet": ip("100.102.103.104"), - "a": ip("100.0.0.201"), - "a.tailnet": ip("100.0.0.201"), - "b": ip("100::202"), - "b.tailnet": ip("100::202"), - }, - }, - { - name: "self_has_v6_only", - nm: &netmap.NetworkMap{ - Name: "foo.tailnet", - SelfNode: (&tailcfg.Node{ - Addresses: []netip.Prefix{ - pfx("100::123/128"), - }, - }).View(), - Peers: nodeViews([]*tailcfg.Node{ - { - Name: "a.tailnet", - Addresses: []netip.Prefix{ - pfx("100.0.0.201/32"), - pfx("100::201/128"), - }, - }, - { - Name: "b.tailnet", - Addresses: []netip.Prefix{ - pfx("100::202/128"), - }, - }, - }), - }, - want: dnsMap{ - "foo": ip("100::123"), - "foo.tailnet": ip("100::123"), - "a": ip("100::201"), - "a.tailnet": ip("100::201"), - "b": ip("100::202"), - "b.tailnet": ip("100::202"), - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := dnsMapFromNetworkMap(tt.nm) - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("mismatch:\n got %v\nwant %v\n", got, tt.want) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tsdial + +import ( + "net/netip" + "reflect" + "testing" + + "tailscale.com/tailcfg" + "tailscale.com/types/netmap" +) + +func nodeViews(v []*tailcfg.Node) []tailcfg.NodeView { + nv := make([]tailcfg.NodeView, len(v)) + for i, n := range v { + nv[i] = n.View() + } + return nv +} + +func TestDNSMapFromNetworkMap(t *testing.T) { + pfx := netip.MustParsePrefix + ip := netip.MustParseAddr + tests := []struct { + name string + nm *netmap.NetworkMap + want dnsMap + }{ + { + name: "self", + nm: &netmap.NetworkMap{ + Name: "foo.tailnet", + SelfNode: (&tailcfg.Node{ + Addresses: []netip.Prefix{ + pfx("100.102.103.104/32"), + pfx("100::123/128"), + }, + }).View(), + }, + want: dnsMap{ + "foo": ip("100.102.103.104"), + "foo.tailnet": ip("100.102.103.104"), + }, + }, + { + name: "self_and_peers", + nm: &netmap.NetworkMap{ + Name: "foo.tailnet", + SelfNode: (&tailcfg.Node{ + Addresses: []netip.Prefix{ + pfx("100.102.103.104/32"), + pfx("100::123/128"), + }, + }).View(), + Peers: []tailcfg.NodeView{ + (&tailcfg.Node{ + Name: "a.tailnet", + Addresses: []netip.Prefix{ + pfx("100.0.0.201/32"), + pfx("100::201/128"), + }, + }).View(), + (&tailcfg.Node{ + Name: "b.tailnet", + Addresses: []netip.Prefix{ + pfx("100::202/128"), + }, + }).View(), + }, + }, + want: dnsMap{ + "foo": ip("100.102.103.104"), + "foo.tailnet": ip("100.102.103.104"), + "a": ip("100.0.0.201"), + "a.tailnet": ip("100.0.0.201"), + "b": ip("100::202"), + "b.tailnet": ip("100::202"), + }, + }, + { + name: "self_has_v6_only", + nm: &netmap.NetworkMap{ + Name: "foo.tailnet", + SelfNode: (&tailcfg.Node{ + Addresses: []netip.Prefix{ + pfx("100::123/128"), + }, + }).View(), + Peers: nodeViews([]*tailcfg.Node{ + { + Name: "a.tailnet", + Addresses: []netip.Prefix{ + pfx("100.0.0.201/32"), + pfx("100::201/128"), + }, + }, + { + Name: "b.tailnet", + Addresses: []netip.Prefix{ + pfx("100::202/128"), + }, + }, + }), + }, + want: dnsMap{ + "foo": ip("100::123"), + "foo.tailnet": ip("100::123"), + "a": ip("100::201"), + "a.tailnet": ip("100::201"), + "b": ip("100::202"), + "b.tailnet": ip("100::202"), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := dnsMapFromNetworkMap(tt.nm) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("mismatch:\n got %v\nwant %v\n", got, tt.want) + } + }) + } +} diff --git a/net/tsdial/dohclient.go b/net/tsdial/dohclient.go index 64c127fd3..d830398cd 100644 --- a/net/tsdial/dohclient.go +++ b/net/tsdial/dohclient.go @@ -1,100 +1,100 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tsdial - -import ( - "bytes" - "context" - "errors" - "fmt" - "io" - "net" - "net/http" - "time" - - "tailscale.com/net/dnscache" -) - -// dohConn is a net.PacketConn suitable for returning from -// net.Dialer.Dial to send DNS queries over PeerAPI to exit nodes' -// ExitDNS DoH proxy service. -type dohConn struct { - ctx context.Context - baseURL string - hc *http.Client // if nil, default is used - dnsCache *dnscache.MessageCache - - rbuf bytes.Buffer -} - -var ( - _ net.Conn = (*dohConn)(nil) - _ net.PacketConn = (*dohConn)(nil) // be a PacketConn to change net.Resolver semantics -) - -func (*dohConn) Close() error { return nil } -func (*dohConn) LocalAddr() net.Addr { return todoAddr{} } -func (*dohConn) RemoteAddr() net.Addr { return todoAddr{} } -func (*dohConn) SetDeadline(t time.Time) error { return nil } -func (*dohConn) SetReadDeadline(t time.Time) error { return nil } -func (*dohConn) SetWriteDeadline(t time.Time) error { return nil } - -func (c *dohConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - return c.Write(p) -} - -func (c *dohConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - n, err = c.Read(p) - return n, todoAddr{}, err -} - -func (c *dohConn) Read(p []byte) (n int, err error) { - return c.rbuf.Read(p) -} - -func (c *dohConn) Write(packet []byte) (n int, err error) { - if c.dnsCache != nil { - err := c.dnsCache.ReplyFromCache(&c.rbuf, packet) - if err == nil { - // Cache hit. - // TODO(bradfitz): add clientmetric - return len(packet), nil - } - c.rbuf.Reset() - } - req, err := http.NewRequestWithContext(c.ctx, "POST", c.baseURL, bytes.NewReader(packet)) - if err != nil { - return 0, err - } - const dohType = "application/dns-message" - req.Header.Set("Content-Type", dohType) - hc := c.hc - if hc == nil { - hc = http.DefaultClient - } - hres, err := hc.Do(req) - if err != nil { - return 0, err - } - defer hres.Body.Close() - if hres.StatusCode != 200 { - return 0, errors.New(hres.Status) - } - if ct := hres.Header.Get("Content-Type"); ct != dohType { - return 0, fmt.Errorf("unexpected response Content-Type %q", ct) - } - _, err = io.Copy(&c.rbuf, hres.Body) - if err != nil { - return 0, err - } - if c.dnsCache != nil { - c.dnsCache.AddCacheEntry(packet, c.rbuf.Bytes()) - } - return len(packet), nil -} - -type todoAddr struct{} - -func (todoAddr) Network() string { return "unused" } -func (todoAddr) String() string { return "unused-todoAddr" } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tsdial + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "time" + + "tailscale.com/net/dnscache" +) + +// dohConn is a net.PacketConn suitable for returning from +// net.Dialer.Dial to send DNS queries over PeerAPI to exit nodes' +// ExitDNS DoH proxy service. +type dohConn struct { + ctx context.Context + baseURL string + hc *http.Client // if nil, default is used + dnsCache *dnscache.MessageCache + + rbuf bytes.Buffer +} + +var ( + _ net.Conn = (*dohConn)(nil) + _ net.PacketConn = (*dohConn)(nil) // be a PacketConn to change net.Resolver semantics +) + +func (*dohConn) Close() error { return nil } +func (*dohConn) LocalAddr() net.Addr { return todoAddr{} } +func (*dohConn) RemoteAddr() net.Addr { return todoAddr{} } +func (*dohConn) SetDeadline(t time.Time) error { return nil } +func (*dohConn) SetReadDeadline(t time.Time) error { return nil } +func (*dohConn) SetWriteDeadline(t time.Time) error { return nil } + +func (c *dohConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + return c.Write(p) +} + +func (c *dohConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, err = c.Read(p) + return n, todoAddr{}, err +} + +func (c *dohConn) Read(p []byte) (n int, err error) { + return c.rbuf.Read(p) +} + +func (c *dohConn) Write(packet []byte) (n int, err error) { + if c.dnsCache != nil { + err := c.dnsCache.ReplyFromCache(&c.rbuf, packet) + if err == nil { + // Cache hit. + // TODO(bradfitz): add clientmetric + return len(packet), nil + } + c.rbuf.Reset() + } + req, err := http.NewRequestWithContext(c.ctx, "POST", c.baseURL, bytes.NewReader(packet)) + if err != nil { + return 0, err + } + const dohType = "application/dns-message" + req.Header.Set("Content-Type", dohType) + hc := c.hc + if hc == nil { + hc = http.DefaultClient + } + hres, err := hc.Do(req) + if err != nil { + return 0, err + } + defer hres.Body.Close() + if hres.StatusCode != 200 { + return 0, errors.New(hres.Status) + } + if ct := hres.Header.Get("Content-Type"); ct != dohType { + return 0, fmt.Errorf("unexpected response Content-Type %q", ct) + } + _, err = io.Copy(&c.rbuf, hres.Body) + if err != nil { + return 0, err + } + if c.dnsCache != nil { + c.dnsCache.AddCacheEntry(packet, c.rbuf.Bytes()) + } + return len(packet), nil +} + +type todoAddr struct{} + +func (todoAddr) Network() string { return "unused" } +func (todoAddr) String() string { return "unused-todoAddr" } diff --git a/net/tsdial/dohclient_test.go b/net/tsdial/dohclient_test.go index 41a66f8f7..23255769f 100644 --- a/net/tsdial/dohclient_test.go +++ b/net/tsdial/dohclient_test.go @@ -1,31 +1,31 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tsdial - -import ( - "context" - "flag" - "net" - "testing" - "time" -) - -var dohBase = flag.String("doh-base", "", "DoH base URL for manual DoH tests; e.g. \"http://100.68.82.120:47830/dns-query\"") - -func TestDoHResolve(t *testing.T) { - if *dohBase == "" { - t.Skip("skipping manual test without --doh-base= set") - } - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - var r net.Resolver - r.Dial = func(ctx context.Context, network, address string) (net.Conn, error) { - return &dohConn{ctx: ctx, baseURL: *dohBase}, nil - } - addrs, err := r.LookupIP(ctx, "ip4", "google.com.") - if err != nil { - t.Fatal(err) - } - t.Logf("Got: %q", addrs) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tsdial + +import ( + "context" + "flag" + "net" + "testing" + "time" +) + +var dohBase = flag.String("doh-base", "", "DoH base URL for manual DoH tests; e.g. \"http://100.68.82.120:47830/dns-query\"") + +func TestDoHResolve(t *testing.T) { + if *dohBase == "" { + t.Skip("skipping manual test without --doh-base= set") + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + var r net.Resolver + r.Dial = func(ctx context.Context, network, address string) (net.Conn, error) { + return &dohConn{ctx: ctx, baseURL: *dohBase}, nil + } + addrs, err := r.LookupIP(ctx, "ip4", "google.com.") + if err != nil { + t.Fatal(err) + } + t.Logf("Got: %q", addrs) +} diff --git a/net/tshttpproxy/mksyscall.go b/net/tshttpproxy/mksyscall.go index 467dc4917..f8fdae89b 100644 --- a/net/tshttpproxy/mksyscall.go +++ b/net/tshttpproxy/mksyscall.go @@ -1,11 +1,11 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tshttpproxy - -//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go - -//sys globalFree(hglobal winHGlobal) (err error) [failretval==0] = kernel32.GlobalFree -//sys winHTTPCloseHandle(whi winHTTPInternet) (err error) [failretval==0] = winhttp.WinHttpCloseHandle -//sys winHTTPGetProxyForURL(whi winHTTPInternet, url *uint16, options *winHTTPAutoProxyOptions, proxyInfo *winHTTPProxyInfo) (err error) [failretval==0] = winhttp.WinHttpGetProxyForUrl -//sys winHTTPOpen(agent *uint16, accessType uint32, proxy *uint16, proxyBypass *uint16, flags uint32) (whi winHTTPInternet, err error) [failretval==0] = winhttp.WinHttpOpen +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tshttpproxy + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go + +//sys globalFree(hglobal winHGlobal) (err error) [failretval==0] = kernel32.GlobalFree +//sys winHTTPCloseHandle(whi winHTTPInternet) (err error) [failretval==0] = winhttp.WinHttpCloseHandle +//sys winHTTPGetProxyForURL(whi winHTTPInternet, url *uint16, options *winHTTPAutoProxyOptions, proxyInfo *winHTTPProxyInfo) (err error) [failretval==0] = winhttp.WinHttpGetProxyForUrl +//sys winHTTPOpen(agent *uint16, accessType uint32, proxy *uint16, proxyBypass *uint16, flags uint32) (whi winHTTPInternet, err error) [failretval==0] = winhttp.WinHttpOpen diff --git a/net/tshttpproxy/tshttpproxy_linux.go b/net/tshttpproxy/tshttpproxy_linux.go index 09019893a..b241c256d 100644 --- a/net/tshttpproxy/tshttpproxy_linux.go +++ b/net/tshttpproxy/tshttpproxy_linux.go @@ -1,24 +1,24 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux - -package tshttpproxy - -import ( - "net/http" - "net/url" - - "tailscale.com/version/distro" -) - -func init() { - sysProxyFromEnv = linuxSysProxyFromEnv -} - -func linuxSysProxyFromEnv(req *http.Request) (*url.URL, error) { - if distro.Get() == distro.Synology { - return synologyProxyFromConfigCached(req) - } - return nil, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package tshttpproxy + +import ( + "net/http" + "net/url" + + "tailscale.com/version/distro" +) + +func init() { + sysProxyFromEnv = linuxSysProxyFromEnv +} + +func linuxSysProxyFromEnv(req *http.Request) (*url.URL, error) { + if distro.Get() == distro.Synology { + return synologyProxyFromConfigCached(req) + } + return nil, nil +} diff --git a/net/tshttpproxy/tshttpproxy_synology_test.go b/net/tshttpproxy/tshttpproxy_synology_test.go index e11c9d059..3061740f3 100644 --- a/net/tshttpproxy/tshttpproxy_synology_test.go +++ b/net/tshttpproxy/tshttpproxy_synology_test.go @@ -1,376 +1,376 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux - -package tshttpproxy - -import ( - "errors" - "fmt" - "io" - "net/http" - "net/url" - "os" - "path/filepath" - "strings" - "testing" - "time" - - "tailscale.com/tstest" -) - -func TestSynologyProxyFromConfigCached(t *testing.T) { - req, err := http.NewRequest("GET", "http://example.org/", nil) - if err != nil { - t.Fatal(err) - } - - tstest.Replace(t, &synologyProxyConfigPath, filepath.Join(t.TempDir(), "proxy.conf")) - - t.Run("no config file", func(t *testing.T) { - if _, err := os.Stat(synologyProxyConfigPath); err == nil { - t.Fatalf("%s must not exist for this test", synologyProxyConfigPath) - } - - cache.updated = time.Time{} - cache.httpProxy = nil - cache.httpsProxy = nil - - if val, err := synologyProxyFromConfigCached(req); val != nil || err != nil { - t.Fatalf("got %s, %v; want nil, nil", val, err) - } - - if got, want := cache.updated, time.Unix(0, 0); got != want { - t.Fatalf("got %s, want %s", got, want) - } - if cache.httpProxy != nil { - t.Fatalf("got %s, want nil", cache.httpProxy) - } - if cache.httpsProxy != nil { - t.Fatalf("got %s, want nil", cache.httpsProxy) - } - }) - - t.Run("config file updated", func(t *testing.T) { - cache.updated = time.Now() - cache.httpProxy = nil - cache.httpsProxy = nil - - if err := os.WriteFile(synologyProxyConfigPath, []byte(` -proxy_enabled=yes -http_host=10.0.0.55 -http_port=80 -https_host=10.0.0.66 -https_port=443 - `), 0600); err != nil { - t.Fatal(err) - } - - val, err := synologyProxyFromConfigCached(req) - if err != nil { - t.Fatal(err) - } - - if cache.httpProxy == nil { - t.Fatal("http proxy was not cached") - } - if cache.httpsProxy == nil { - t.Fatal("https proxy was not cached") - } - - if want := urlMustParse("http://10.0.0.55:80"); val.String() != want.String() { - t.Fatalf("got %s; want %s", val, want) - } - }) - - t.Run("config file removed", func(t *testing.T) { - cache.updated = time.Now() - cache.httpProxy = urlMustParse("http://127.0.0.1/") - cache.httpsProxy = urlMustParse("http://127.0.0.1/") - - if err := os.Remove(synologyProxyConfigPath); err != nil && !os.IsNotExist(err) { - t.Fatal(err) - } - - val, err := synologyProxyFromConfigCached(req) - if err != nil { - t.Fatal(err) - } - if val != nil { - t.Fatalf("got %s; want nil", val) - } - if cache.httpProxy != nil { - t.Fatalf("got %s, want nil", cache.httpProxy) - } - if cache.httpsProxy != nil { - t.Fatalf("got %s, want nil", cache.httpsProxy) - } - }) - - t.Run("picks proxy from request scheme", func(t *testing.T) { - cache.updated = time.Now() - cache.httpProxy = nil - cache.httpsProxy = nil - - if err := os.WriteFile(synologyProxyConfigPath, []byte(` -proxy_enabled=yes -http_host=10.0.0.55 -http_port=80 -https_host=10.0.0.66 -https_port=443 - `), 0600); err != nil { - t.Fatal(err) - } - - httpReq, err := http.NewRequest("GET", "http://example.com", nil) - if err != nil { - t.Fatal(err) - } - val, err := synologyProxyFromConfigCached(httpReq) - if err != nil { - t.Fatal(err) - } - if val == nil { - t.Fatalf("got nil, want an http URL") - } - if got, want := val.String(), "http://10.0.0.55:80"; got != want { - t.Fatalf("got %q, want %q", got, want) - } - - httpsReq, err := http.NewRequest("GET", "https://example.com", nil) - if err != nil { - t.Fatal(err) - } - val, err = synologyProxyFromConfigCached(httpsReq) - if err != nil { - t.Fatal(err) - } - if val == nil { - t.Fatalf("got nil, want an http URL") - } - if got, want := val.String(), "http://10.0.0.66:443"; got != want { - t.Fatalf("got %q, want %q", got, want) - } - }) -} - -func TestSynologyProxiesFromConfig(t *testing.T) { - var ( - openReader io.ReadCloser - openErr error - ) - tstest.Replace(t, &openSynologyProxyConf, func() (io.ReadCloser, error) { - return openReader, openErr - }) - - t.Run("with config", func(t *testing.T) { - mc := &mustCloser{Reader: strings.NewReader(` -proxy_user=foo -proxy_pwd=bar -proxy_enabled=yes -adv_enabled=yes -bypass_enabled=yes -auth_enabled=yes -https_host=10.0.0.66 -https_port=8443 -http_host=10.0.0.55 -http_port=80 - `)} - defer mc.check(t) - openReader = mc - - httpProxy, httpsProxy, err := synologyProxiesFromConfig() - - if got, want := err, openErr; got != want { - t.Fatalf("got %s, want %s", got, want) - } - - if got, want := httpsProxy, urlMustParse("http://foo:bar@10.0.0.66:8443"); got.String() != want.String() { - t.Fatalf("got %s, want %s", got, want) - } - - if got, want := err, openErr; got != want { - t.Fatalf("got %s, want %s", got, want) - } - - if got, want := httpProxy, urlMustParse("http://foo:bar@10.0.0.55:80"); got.String() != want.String() { - t.Fatalf("got %s, want %s", got, want) - } - - }) - - t.Run("nonexistent config", func(t *testing.T) { - openReader = nil - openErr = os.ErrNotExist - - httpProxy, httpsProxy, err := synologyProxiesFromConfig() - if err != nil { - t.Fatalf("expected no error, got %s", err) - } - if httpProxy != nil { - t.Fatalf("expected no url, got %s", httpProxy) - } - if httpsProxy != nil { - t.Fatalf("expected no url, got %s", httpsProxy) - } - }) - - t.Run("error opening config", func(t *testing.T) { - openReader = nil - openErr = errors.New("example error") - - httpProxy, httpsProxy, err := synologyProxiesFromConfig() - if err != openErr { - t.Fatalf("expected %s, got %s", openErr, err) - } - if httpProxy != nil { - t.Fatalf("expected no url, got %s", httpProxy) - } - if httpsProxy != nil { - t.Fatalf("expected no url, got %s", httpsProxy) - } - }) - -} - -func TestParseSynologyConfig(t *testing.T) { - cases := map[string]struct { - input string - httpProxy *url.URL - httpsProxy *url.URL - err error - }{ - "populated": { - input: ` -proxy_user=foo -proxy_pwd=bar -proxy_enabled=yes -adv_enabled=yes -bypass_enabled=yes -auth_enabled=yes -https_host=10.0.0.66 -https_port=8443 -http_host=10.0.0.55 -http_port=80 -`, - httpProxy: urlMustParse("http://foo:bar@10.0.0.55:80"), - httpsProxy: urlMustParse("http://foo:bar@10.0.0.66:8443"), - err: nil, - }, - "no-auth": { - input: ` -proxy_user=foo -proxy_pwd=bar -proxy_enabled=yes -adv_enabled=yes -bypass_enabled=yes -auth_enabled=no -https_host=10.0.0.66 -https_port=8443 -http_host=10.0.0.55 -http_port=80 -`, - httpProxy: urlMustParse("http://10.0.0.55:80"), - httpsProxy: urlMustParse("http://10.0.0.66:8443"), - err: nil, - }, - "http-only": { - input: ` -proxy_user=foo -proxy_pwd=bar -proxy_enabled=yes -adv_enabled=yes -bypass_enabled=yes -auth_enabled=yes -https_host= -https_port=8443 -http_host=10.0.0.55 -http_port=80 -`, - httpProxy: urlMustParse("http://foo:bar@10.0.0.55:80"), - httpsProxy: nil, - err: nil, - }, - "empty": { - input: ` -proxy_user= -proxy_pwd= -proxy_enabled= -adv_enabled= -bypass_enabled= -auth_enabled= -https_host= -https_port= -http_host= -http_port= -`, - httpProxy: nil, - httpsProxy: nil, - err: nil, - }, - } - - for name, example := range cases { - t.Run(name, func(t *testing.T) { - httpProxy, httpsProxy, err := parseSynologyConfig(strings.NewReader(example.input)) - if err != example.err { - t.Fatal(err) - } - if example.err != nil { - return - } - - if example.httpProxy == nil && httpProxy != nil { - t.Fatalf("got %s, want nil", httpProxy) - } - - if example.httpProxy != nil { - if httpProxy == nil { - t.Fatalf("got nil, want %s", example.httpProxy) - } - - if got, want := example.httpProxy.String(), httpProxy.String(); got != want { - t.Fatalf("got %s, want %s", got, want) - } - } - - if example.httpsProxy == nil && httpsProxy != nil { - t.Fatalf("got %s, want nil", httpProxy) - } - - if example.httpsProxy != nil { - if httpsProxy == nil { - t.Fatalf("got nil, want %s", example.httpsProxy) - } - - if got, want := example.httpsProxy.String(), httpsProxy.String(); got != want { - t.Fatalf("got %s, want %s", got, want) - } - } - }) - } -} -func urlMustParse(u string) *url.URL { - r, err := url.Parse(u) - if err != nil { - panic(fmt.Sprintf("urlMustParse: %s", err)) - } - return r -} - -type mustCloser struct { - io.Reader - closed bool -} - -func (m *mustCloser) Close() error { - m.closed = true - return nil -} - -func (m *mustCloser) check(t *testing.T) { - if !m.closed { - t.Errorf("mustCloser wrapping %#v was not closed at time of check", m.Reader) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package tshttpproxy + +import ( + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "tailscale.com/tstest" +) + +func TestSynologyProxyFromConfigCached(t *testing.T) { + req, err := http.NewRequest("GET", "http://example.org/", nil) + if err != nil { + t.Fatal(err) + } + + tstest.Replace(t, &synologyProxyConfigPath, filepath.Join(t.TempDir(), "proxy.conf")) + + t.Run("no config file", func(t *testing.T) { + if _, err := os.Stat(synologyProxyConfigPath); err == nil { + t.Fatalf("%s must not exist for this test", synologyProxyConfigPath) + } + + cache.updated = time.Time{} + cache.httpProxy = nil + cache.httpsProxy = nil + + if val, err := synologyProxyFromConfigCached(req); val != nil || err != nil { + t.Fatalf("got %s, %v; want nil, nil", val, err) + } + + if got, want := cache.updated, time.Unix(0, 0); got != want { + t.Fatalf("got %s, want %s", got, want) + } + if cache.httpProxy != nil { + t.Fatalf("got %s, want nil", cache.httpProxy) + } + if cache.httpsProxy != nil { + t.Fatalf("got %s, want nil", cache.httpsProxy) + } + }) + + t.Run("config file updated", func(t *testing.T) { + cache.updated = time.Now() + cache.httpProxy = nil + cache.httpsProxy = nil + + if err := os.WriteFile(synologyProxyConfigPath, []byte(` +proxy_enabled=yes +http_host=10.0.0.55 +http_port=80 +https_host=10.0.0.66 +https_port=443 + `), 0600); err != nil { + t.Fatal(err) + } + + val, err := synologyProxyFromConfigCached(req) + if err != nil { + t.Fatal(err) + } + + if cache.httpProxy == nil { + t.Fatal("http proxy was not cached") + } + if cache.httpsProxy == nil { + t.Fatal("https proxy was not cached") + } + + if want := urlMustParse("http://10.0.0.55:80"); val.String() != want.String() { + t.Fatalf("got %s; want %s", val, want) + } + }) + + t.Run("config file removed", func(t *testing.T) { + cache.updated = time.Now() + cache.httpProxy = urlMustParse("http://127.0.0.1/") + cache.httpsProxy = urlMustParse("http://127.0.0.1/") + + if err := os.Remove(synologyProxyConfigPath); err != nil && !os.IsNotExist(err) { + t.Fatal(err) + } + + val, err := synologyProxyFromConfigCached(req) + if err != nil { + t.Fatal(err) + } + if val != nil { + t.Fatalf("got %s; want nil", val) + } + if cache.httpProxy != nil { + t.Fatalf("got %s, want nil", cache.httpProxy) + } + if cache.httpsProxy != nil { + t.Fatalf("got %s, want nil", cache.httpsProxy) + } + }) + + t.Run("picks proxy from request scheme", func(t *testing.T) { + cache.updated = time.Now() + cache.httpProxy = nil + cache.httpsProxy = nil + + if err := os.WriteFile(synologyProxyConfigPath, []byte(` +proxy_enabled=yes +http_host=10.0.0.55 +http_port=80 +https_host=10.0.0.66 +https_port=443 + `), 0600); err != nil { + t.Fatal(err) + } + + httpReq, err := http.NewRequest("GET", "http://example.com", nil) + if err != nil { + t.Fatal(err) + } + val, err := synologyProxyFromConfigCached(httpReq) + if err != nil { + t.Fatal(err) + } + if val == nil { + t.Fatalf("got nil, want an http URL") + } + if got, want := val.String(), "http://10.0.0.55:80"; got != want { + t.Fatalf("got %q, want %q", got, want) + } + + httpsReq, err := http.NewRequest("GET", "https://example.com", nil) + if err != nil { + t.Fatal(err) + } + val, err = synologyProxyFromConfigCached(httpsReq) + if err != nil { + t.Fatal(err) + } + if val == nil { + t.Fatalf("got nil, want an http URL") + } + if got, want := val.String(), "http://10.0.0.66:443"; got != want { + t.Fatalf("got %q, want %q", got, want) + } + }) +} + +func TestSynologyProxiesFromConfig(t *testing.T) { + var ( + openReader io.ReadCloser + openErr error + ) + tstest.Replace(t, &openSynologyProxyConf, func() (io.ReadCloser, error) { + return openReader, openErr + }) + + t.Run("with config", func(t *testing.T) { + mc := &mustCloser{Reader: strings.NewReader(` +proxy_user=foo +proxy_pwd=bar +proxy_enabled=yes +adv_enabled=yes +bypass_enabled=yes +auth_enabled=yes +https_host=10.0.0.66 +https_port=8443 +http_host=10.0.0.55 +http_port=80 + `)} + defer mc.check(t) + openReader = mc + + httpProxy, httpsProxy, err := synologyProxiesFromConfig() + + if got, want := err, openErr; got != want { + t.Fatalf("got %s, want %s", got, want) + } + + if got, want := httpsProxy, urlMustParse("http://foo:bar@10.0.0.66:8443"); got.String() != want.String() { + t.Fatalf("got %s, want %s", got, want) + } + + if got, want := err, openErr; got != want { + t.Fatalf("got %s, want %s", got, want) + } + + if got, want := httpProxy, urlMustParse("http://foo:bar@10.0.0.55:80"); got.String() != want.String() { + t.Fatalf("got %s, want %s", got, want) + } + + }) + + t.Run("nonexistent config", func(t *testing.T) { + openReader = nil + openErr = os.ErrNotExist + + httpProxy, httpsProxy, err := synologyProxiesFromConfig() + if err != nil { + t.Fatalf("expected no error, got %s", err) + } + if httpProxy != nil { + t.Fatalf("expected no url, got %s", httpProxy) + } + if httpsProxy != nil { + t.Fatalf("expected no url, got %s", httpsProxy) + } + }) + + t.Run("error opening config", func(t *testing.T) { + openReader = nil + openErr = errors.New("example error") + + httpProxy, httpsProxy, err := synologyProxiesFromConfig() + if err != openErr { + t.Fatalf("expected %s, got %s", openErr, err) + } + if httpProxy != nil { + t.Fatalf("expected no url, got %s", httpProxy) + } + if httpsProxy != nil { + t.Fatalf("expected no url, got %s", httpsProxy) + } + }) + +} + +func TestParseSynologyConfig(t *testing.T) { + cases := map[string]struct { + input string + httpProxy *url.URL + httpsProxy *url.URL + err error + }{ + "populated": { + input: ` +proxy_user=foo +proxy_pwd=bar +proxy_enabled=yes +adv_enabled=yes +bypass_enabled=yes +auth_enabled=yes +https_host=10.0.0.66 +https_port=8443 +http_host=10.0.0.55 +http_port=80 +`, + httpProxy: urlMustParse("http://foo:bar@10.0.0.55:80"), + httpsProxy: urlMustParse("http://foo:bar@10.0.0.66:8443"), + err: nil, + }, + "no-auth": { + input: ` +proxy_user=foo +proxy_pwd=bar +proxy_enabled=yes +adv_enabled=yes +bypass_enabled=yes +auth_enabled=no +https_host=10.0.0.66 +https_port=8443 +http_host=10.0.0.55 +http_port=80 +`, + httpProxy: urlMustParse("http://10.0.0.55:80"), + httpsProxy: urlMustParse("http://10.0.0.66:8443"), + err: nil, + }, + "http-only": { + input: ` +proxy_user=foo +proxy_pwd=bar +proxy_enabled=yes +adv_enabled=yes +bypass_enabled=yes +auth_enabled=yes +https_host= +https_port=8443 +http_host=10.0.0.55 +http_port=80 +`, + httpProxy: urlMustParse("http://foo:bar@10.0.0.55:80"), + httpsProxy: nil, + err: nil, + }, + "empty": { + input: ` +proxy_user= +proxy_pwd= +proxy_enabled= +adv_enabled= +bypass_enabled= +auth_enabled= +https_host= +https_port= +http_host= +http_port= +`, + httpProxy: nil, + httpsProxy: nil, + err: nil, + }, + } + + for name, example := range cases { + t.Run(name, func(t *testing.T) { + httpProxy, httpsProxy, err := parseSynologyConfig(strings.NewReader(example.input)) + if err != example.err { + t.Fatal(err) + } + if example.err != nil { + return + } + + if example.httpProxy == nil && httpProxy != nil { + t.Fatalf("got %s, want nil", httpProxy) + } + + if example.httpProxy != nil { + if httpProxy == nil { + t.Fatalf("got nil, want %s", example.httpProxy) + } + + if got, want := example.httpProxy.String(), httpProxy.String(); got != want { + t.Fatalf("got %s, want %s", got, want) + } + } + + if example.httpsProxy == nil && httpsProxy != nil { + t.Fatalf("got %s, want nil", httpProxy) + } + + if example.httpsProxy != nil { + if httpsProxy == nil { + t.Fatalf("got nil, want %s", example.httpsProxy) + } + + if got, want := example.httpsProxy.String(), httpsProxy.String(); got != want { + t.Fatalf("got %s, want %s", got, want) + } + } + }) + } +} +func urlMustParse(u string) *url.URL { + r, err := url.Parse(u) + if err != nil { + panic(fmt.Sprintf("urlMustParse: %s", err)) + } + return r +} + +type mustCloser struct { + io.Reader + closed bool +} + +func (m *mustCloser) Close() error { + m.closed = true + return nil +} + +func (m *mustCloser) check(t *testing.T) { + if !m.closed { + t.Errorf("mustCloser wrapping %#v was not closed at time of check", m.Reader) + } +} diff --git a/net/tshttpproxy/tshttpproxy_windows.go b/net/tshttpproxy/tshttpproxy_windows.go index cb6b24c83..06a1f5ae4 100644 --- a/net/tshttpproxy/tshttpproxy_windows.go +++ b/net/tshttpproxy/tshttpproxy_windows.go @@ -1,276 +1,276 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tshttpproxy - -import ( - "context" - "encoding/base64" - "fmt" - "log" - "net/http" - "net/url" - "runtime" - "strings" - "sync" - "syscall" - "time" - "unsafe" - - "github.com/alexbrainman/sspi/negotiate" - "golang.org/x/sys/windows" - "tailscale.com/hostinfo" - "tailscale.com/syncs" - "tailscale.com/types/logger" - "tailscale.com/util/clientmetric" - "tailscale.com/util/cmpver" -) - -func init() { - sysProxyFromEnv = proxyFromWinHTTPOrCache - sysAuthHeader = sysAuthHeaderWindows -} - -var cachedProxy struct { - sync.Mutex - val *url.URL -} - -// proxyErrorf is a rate-limited logger specifically for errors asking -// WinHTTP for the proxy information. We don't want to log about -// errors often, otherwise the log message itself will generate a new -// HTTP request which ultimately will call back into us to log again, -// forever. So for errors, we only log a bit. -var proxyErrorf = logger.RateLimitedFn(log.Printf, 10*time.Minute, 2 /* burst*/, 10 /* maxCache */) - -var ( - metricSuccess = clientmetric.NewCounter("winhttp_proxy_success") - metricErrDetectionFailed = clientmetric.NewCounter("winhttp_proxy_err_detection_failed") - metricErrInvalidParameters = clientmetric.NewCounter("winhttp_proxy_err_invalid_param") - metricErrDownloadScript = clientmetric.NewCounter("winhttp_proxy_err_download_script") - metricErrTimeout = clientmetric.NewCounter("winhttp_proxy_err_timeout") - metricErrOther = clientmetric.NewCounter("winhttp_proxy_err_other") -) - -func proxyFromWinHTTPOrCache(req *http.Request) (*url.URL, error) { - if req.URL == nil { - return nil, nil - } - urlStr := req.URL.String() - - ctx, cancel := context.WithTimeout(req.Context(), 5*time.Second) - defer cancel() - - type result struct { - proxy *url.URL - err error - } - resc := make(chan result, 1) - go func() { - proxy, err := proxyFromWinHTTP(ctx, urlStr) - resc <- result{proxy, err} - }() - - select { - case res := <-resc: - err := res.err - if err == nil { - metricSuccess.Add(1) - cachedProxy.Lock() - defer cachedProxy.Unlock() - if was, now := fmt.Sprint(cachedProxy.val), fmt.Sprint(res.proxy); was != now { - log.Printf("tshttpproxy: winhttp: updating cached proxy setting from %v to %v", was, now) - } - cachedProxy.val = res.proxy - return res.proxy, nil - } - - // See https://docs.microsoft.com/en-us/windows/win32/winhttp/error-messages - const ( - ERROR_WINHTTP_AUTODETECTION_FAILED = 12180 - ERROR_WINHTTP_UNABLE_TO_DOWNLOAD_SCRIPT = 12167 - ) - if err == syscall.Errno(ERROR_WINHTTP_AUTODETECTION_FAILED) { - metricErrDetectionFailed.Add(1) - setNoProxyUntil(10 * time.Second) - return nil, nil - } - if err == windows.ERROR_INVALID_PARAMETER { - metricErrInvalidParameters.Add(1) - // Seen on Windows 8.1. (https://github.com/tailscale/tailscale/issues/879) - // TODO(bradfitz): figure this out. - setNoProxyUntil(time.Hour) - proxyErrorf("tshttpproxy: winhttp: GetProxyForURL(%q): ERROR_INVALID_PARAMETER [unexpected]", urlStr) - return nil, nil - } - proxyErrorf("tshttpproxy: winhttp: GetProxyForURL(%q): %v/%#v", urlStr, err, err) - if err == syscall.Errno(ERROR_WINHTTP_UNABLE_TO_DOWNLOAD_SCRIPT) { - metricErrDownloadScript.Add(1) - setNoProxyUntil(10 * time.Second) - return nil, nil - } - metricErrOther.Add(1) - return nil, err - case <-ctx.Done(): - metricErrTimeout.Add(1) - cachedProxy.Lock() - defer cachedProxy.Unlock() - proxyErrorf("tshttpproxy: winhttp: GetProxyForURL(%q): timeout; using cached proxy %v", urlStr, cachedProxy.val) - return cachedProxy.val, nil - } -} - -func proxyFromWinHTTP(ctx context.Context, urlStr string) (proxy *url.URL, err error) { - runtime.LockOSThread() - defer runtime.UnlockOSThread() - - whi, err := httpOpen() - if err != nil { - proxyErrorf("winhttp: Open: %v", err) - return nil, err - } - defer whi.Close() - - t0 := time.Now() - v, err := whi.GetProxyForURL(urlStr) - td := time.Since(t0).Round(time.Millisecond) - if err := ctx.Err(); err != nil { - log.Printf("tshttpproxy: winhttp: context canceled, ignoring GetProxyForURL(%q) after %v", urlStr, td) - return nil, err - } - if err != nil { - return nil, err - } - if v == "" { - return nil, nil - } - // Discard all but first proxy value for now. - if i := strings.Index(v, ";"); i != -1 { - v = v[:i] - } - if !strings.HasPrefix(v, "https://") { - v = "http://" + v - } - return url.Parse(v) -} - -var userAgent = windows.StringToUTF16Ptr("Tailscale") - -const ( - winHTTP_ACCESS_TYPE_DEFAULT_PROXY = 0 - winHTTP_ACCESS_TYPE_AUTOMATIC_PROXY = 4 - winHTTP_AUTOPROXY_ALLOW_AUTOCONFIG = 0x00000100 - winHTTP_AUTOPROXY_AUTO_DETECT = 1 - winHTTP_AUTO_DETECT_TYPE_DHCP = 0x00000001 - winHTTP_AUTO_DETECT_TYPE_DNS_A = 0x00000002 -) - -// Windows 8.1 is actually Windows 6.3 under the hood. Yay, marketing! -const win8dot1Ver = "6.3" - -// accessType is the flag we must pass to WinHttpOpen for proxy resolution -// depending on whether or not we're running Windows < 8.1 -var accessType syncs.AtomicValue[uint32] - -func getAccessFlag() uint32 { - if flag, ok := accessType.LoadOk(); ok { - return flag - } - var flag uint32 - if cmpver.Compare(hostinfo.GetOSVersion(), win8dot1Ver) < 0 { - flag = winHTTP_ACCESS_TYPE_DEFAULT_PROXY - } else { - flag = winHTTP_ACCESS_TYPE_AUTOMATIC_PROXY - } - accessType.Store(flag) - return flag -} - -func httpOpen() (winHTTPInternet, error) { - return winHTTPOpen( - userAgent, - getAccessFlag(), - nil, /* WINHTTP_NO_PROXY_NAME */ - nil, /* WINHTTP_NO_PROXY_BYPASS */ - 0, - ) -} - -type winHTTPInternet windows.Handle - -func (hi winHTTPInternet) Close() error { - return winHTTPCloseHandle(hi) -} - -// WINHTTP_AUTOPROXY_OPTIONS -// https://docs.microsoft.com/en-us/windows/win32/api/winhttp/ns-winhttp-winhttp_autoproxy_options -type winHTTPAutoProxyOptions struct { - DwFlags uint32 - DwAutoDetectFlags uint32 - AutoConfigUrl *uint16 - _ uintptr - _ uint32 - FAutoLogonIfChallenged int32 // BOOL -} - -// WINHTTP_PROXY_INFO -// https://docs.microsoft.com/en-us/windows/win32/api/winhttp/ns-winhttp-winhttp_proxy_info -type winHTTPProxyInfo struct { - AccessType uint32 - Proxy *uint16 - ProxyBypass *uint16 -} - -type winHGlobal windows.Handle - -func globalFreeUTF16Ptr(p *uint16) error { - return globalFree((winHGlobal)(unsafe.Pointer(p))) -} - -func (pi *winHTTPProxyInfo) free() { - if pi.Proxy != nil { - globalFreeUTF16Ptr(pi.Proxy) - pi.Proxy = nil - } - if pi.ProxyBypass != nil { - globalFreeUTF16Ptr(pi.ProxyBypass) - pi.ProxyBypass = nil - } -} - -var proxyForURLOpts = &winHTTPAutoProxyOptions{ - DwFlags: winHTTP_AUTOPROXY_ALLOW_AUTOCONFIG | winHTTP_AUTOPROXY_AUTO_DETECT, - DwAutoDetectFlags: winHTTP_AUTO_DETECT_TYPE_DHCP, // | winHTTP_AUTO_DETECT_TYPE_DNS_A, -} - -func (hi winHTTPInternet) GetProxyForURL(urlStr string) (string, error) { - var out winHTTPProxyInfo - err := winHTTPGetProxyForURL( - hi, - windows.StringToUTF16Ptr(urlStr), - proxyForURLOpts, - &out, - ) - if err != nil { - return "", err - } - defer out.free() - return windows.UTF16PtrToString(out.Proxy), nil -} - -func sysAuthHeaderWindows(u *url.URL) (string, error) { - spn := "HTTP/" + u.Hostname() - creds, err := negotiate.AcquireCurrentUserCredentials() - if err != nil { - return "", fmt.Errorf("negotiate.AcquireCurrentUserCredentials: %w", err) - } - defer creds.Release() - - secCtx, token, err := negotiate.NewClientContext(creds, spn) - if err != nil { - return "", fmt.Errorf("negotiate.NewClientContext: %w", err) - } - defer secCtx.Release() - - return "Negotiate " + base64.StdEncoding.EncodeToString(token), nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tshttpproxy + +import ( + "context" + "encoding/base64" + "fmt" + "log" + "net/http" + "net/url" + "runtime" + "strings" + "sync" + "syscall" + "time" + "unsafe" + + "github.com/alexbrainman/sspi/negotiate" + "golang.org/x/sys/windows" + "tailscale.com/hostinfo" + "tailscale.com/syncs" + "tailscale.com/types/logger" + "tailscale.com/util/clientmetric" + "tailscale.com/util/cmpver" +) + +func init() { + sysProxyFromEnv = proxyFromWinHTTPOrCache + sysAuthHeader = sysAuthHeaderWindows +} + +var cachedProxy struct { + sync.Mutex + val *url.URL +} + +// proxyErrorf is a rate-limited logger specifically for errors asking +// WinHTTP for the proxy information. We don't want to log about +// errors often, otherwise the log message itself will generate a new +// HTTP request which ultimately will call back into us to log again, +// forever. So for errors, we only log a bit. +var proxyErrorf = logger.RateLimitedFn(log.Printf, 10*time.Minute, 2 /* burst*/, 10 /* maxCache */) + +var ( + metricSuccess = clientmetric.NewCounter("winhttp_proxy_success") + metricErrDetectionFailed = clientmetric.NewCounter("winhttp_proxy_err_detection_failed") + metricErrInvalidParameters = clientmetric.NewCounter("winhttp_proxy_err_invalid_param") + metricErrDownloadScript = clientmetric.NewCounter("winhttp_proxy_err_download_script") + metricErrTimeout = clientmetric.NewCounter("winhttp_proxy_err_timeout") + metricErrOther = clientmetric.NewCounter("winhttp_proxy_err_other") +) + +func proxyFromWinHTTPOrCache(req *http.Request) (*url.URL, error) { + if req.URL == nil { + return nil, nil + } + urlStr := req.URL.String() + + ctx, cancel := context.WithTimeout(req.Context(), 5*time.Second) + defer cancel() + + type result struct { + proxy *url.URL + err error + } + resc := make(chan result, 1) + go func() { + proxy, err := proxyFromWinHTTP(ctx, urlStr) + resc <- result{proxy, err} + }() + + select { + case res := <-resc: + err := res.err + if err == nil { + metricSuccess.Add(1) + cachedProxy.Lock() + defer cachedProxy.Unlock() + if was, now := fmt.Sprint(cachedProxy.val), fmt.Sprint(res.proxy); was != now { + log.Printf("tshttpproxy: winhttp: updating cached proxy setting from %v to %v", was, now) + } + cachedProxy.val = res.proxy + return res.proxy, nil + } + + // See https://docs.microsoft.com/en-us/windows/win32/winhttp/error-messages + const ( + ERROR_WINHTTP_AUTODETECTION_FAILED = 12180 + ERROR_WINHTTP_UNABLE_TO_DOWNLOAD_SCRIPT = 12167 + ) + if err == syscall.Errno(ERROR_WINHTTP_AUTODETECTION_FAILED) { + metricErrDetectionFailed.Add(1) + setNoProxyUntil(10 * time.Second) + return nil, nil + } + if err == windows.ERROR_INVALID_PARAMETER { + metricErrInvalidParameters.Add(1) + // Seen on Windows 8.1. (https://github.com/tailscale/tailscale/issues/879) + // TODO(bradfitz): figure this out. + setNoProxyUntil(time.Hour) + proxyErrorf("tshttpproxy: winhttp: GetProxyForURL(%q): ERROR_INVALID_PARAMETER [unexpected]", urlStr) + return nil, nil + } + proxyErrorf("tshttpproxy: winhttp: GetProxyForURL(%q): %v/%#v", urlStr, err, err) + if err == syscall.Errno(ERROR_WINHTTP_UNABLE_TO_DOWNLOAD_SCRIPT) { + metricErrDownloadScript.Add(1) + setNoProxyUntil(10 * time.Second) + return nil, nil + } + metricErrOther.Add(1) + return nil, err + case <-ctx.Done(): + metricErrTimeout.Add(1) + cachedProxy.Lock() + defer cachedProxy.Unlock() + proxyErrorf("tshttpproxy: winhttp: GetProxyForURL(%q): timeout; using cached proxy %v", urlStr, cachedProxy.val) + return cachedProxy.val, nil + } +} + +func proxyFromWinHTTP(ctx context.Context, urlStr string) (proxy *url.URL, err error) { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + whi, err := httpOpen() + if err != nil { + proxyErrorf("winhttp: Open: %v", err) + return nil, err + } + defer whi.Close() + + t0 := time.Now() + v, err := whi.GetProxyForURL(urlStr) + td := time.Since(t0).Round(time.Millisecond) + if err := ctx.Err(); err != nil { + log.Printf("tshttpproxy: winhttp: context canceled, ignoring GetProxyForURL(%q) after %v", urlStr, td) + return nil, err + } + if err != nil { + return nil, err + } + if v == "" { + return nil, nil + } + // Discard all but first proxy value for now. + if i := strings.Index(v, ";"); i != -1 { + v = v[:i] + } + if !strings.HasPrefix(v, "https://") { + v = "http://" + v + } + return url.Parse(v) +} + +var userAgent = windows.StringToUTF16Ptr("Tailscale") + +const ( + winHTTP_ACCESS_TYPE_DEFAULT_PROXY = 0 + winHTTP_ACCESS_TYPE_AUTOMATIC_PROXY = 4 + winHTTP_AUTOPROXY_ALLOW_AUTOCONFIG = 0x00000100 + winHTTP_AUTOPROXY_AUTO_DETECT = 1 + winHTTP_AUTO_DETECT_TYPE_DHCP = 0x00000001 + winHTTP_AUTO_DETECT_TYPE_DNS_A = 0x00000002 +) + +// Windows 8.1 is actually Windows 6.3 under the hood. Yay, marketing! +const win8dot1Ver = "6.3" + +// accessType is the flag we must pass to WinHttpOpen for proxy resolution +// depending on whether or not we're running Windows < 8.1 +var accessType syncs.AtomicValue[uint32] + +func getAccessFlag() uint32 { + if flag, ok := accessType.LoadOk(); ok { + return flag + } + var flag uint32 + if cmpver.Compare(hostinfo.GetOSVersion(), win8dot1Ver) < 0 { + flag = winHTTP_ACCESS_TYPE_DEFAULT_PROXY + } else { + flag = winHTTP_ACCESS_TYPE_AUTOMATIC_PROXY + } + accessType.Store(flag) + return flag +} + +func httpOpen() (winHTTPInternet, error) { + return winHTTPOpen( + userAgent, + getAccessFlag(), + nil, /* WINHTTP_NO_PROXY_NAME */ + nil, /* WINHTTP_NO_PROXY_BYPASS */ + 0, + ) +} + +type winHTTPInternet windows.Handle + +func (hi winHTTPInternet) Close() error { + return winHTTPCloseHandle(hi) +} + +// WINHTTP_AUTOPROXY_OPTIONS +// https://docs.microsoft.com/en-us/windows/win32/api/winhttp/ns-winhttp-winhttp_autoproxy_options +type winHTTPAutoProxyOptions struct { + DwFlags uint32 + DwAutoDetectFlags uint32 + AutoConfigUrl *uint16 + _ uintptr + _ uint32 + FAutoLogonIfChallenged int32 // BOOL +} + +// WINHTTP_PROXY_INFO +// https://docs.microsoft.com/en-us/windows/win32/api/winhttp/ns-winhttp-winhttp_proxy_info +type winHTTPProxyInfo struct { + AccessType uint32 + Proxy *uint16 + ProxyBypass *uint16 +} + +type winHGlobal windows.Handle + +func globalFreeUTF16Ptr(p *uint16) error { + return globalFree((winHGlobal)(unsafe.Pointer(p))) +} + +func (pi *winHTTPProxyInfo) free() { + if pi.Proxy != nil { + globalFreeUTF16Ptr(pi.Proxy) + pi.Proxy = nil + } + if pi.ProxyBypass != nil { + globalFreeUTF16Ptr(pi.ProxyBypass) + pi.ProxyBypass = nil + } +} + +var proxyForURLOpts = &winHTTPAutoProxyOptions{ + DwFlags: winHTTP_AUTOPROXY_ALLOW_AUTOCONFIG | winHTTP_AUTOPROXY_AUTO_DETECT, + DwAutoDetectFlags: winHTTP_AUTO_DETECT_TYPE_DHCP, // | winHTTP_AUTO_DETECT_TYPE_DNS_A, +} + +func (hi winHTTPInternet) GetProxyForURL(urlStr string) (string, error) { + var out winHTTPProxyInfo + err := winHTTPGetProxyForURL( + hi, + windows.StringToUTF16Ptr(urlStr), + proxyForURLOpts, + &out, + ) + if err != nil { + return "", err + } + defer out.free() + return windows.UTF16PtrToString(out.Proxy), nil +} + +func sysAuthHeaderWindows(u *url.URL) (string, error) { + spn := "HTTP/" + u.Hostname() + creds, err := negotiate.AcquireCurrentUserCredentials() + if err != nil { + return "", fmt.Errorf("negotiate.AcquireCurrentUserCredentials: %w", err) + } + defer creds.Release() + + secCtx, token, err := negotiate.NewClientContext(creds, spn) + if err != nil { + return "", fmt.Errorf("negotiate.NewClientContext: %w", err) + } + defer secCtx.Release() + + return "Negotiate " + base64.StdEncoding.EncodeToString(token), nil +} diff --git a/net/tstun/fake.go b/net/tstun/fake.go index a002952a3..3d86bb3df 100644 --- a/net/tstun/fake.go +++ b/net/tstun/fake.go @@ -1,58 +1,58 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstun - -import ( - "io" - "os" - - "github.com/tailscale/wireguard-go/tun" -) - -type fakeTUN struct { - evchan chan tun.Event - closechan chan struct{} -} - -// NewFake returns a tun.Device that does nothing. -func NewFake() tun.Device { - return &fakeTUN{ - evchan: make(chan tun.Event), - closechan: make(chan struct{}), - } -} - -func (t *fakeTUN) File() *os.File { - panic("fakeTUN.File() called, which makes no sense") -} - -func (t *fakeTUN) Close() error { - close(t.closechan) - close(t.evchan) - return nil -} - -func (t *fakeTUN) Read(out [][]byte, sizes []int, offset int) (int, error) { - <-t.closechan - return 0, io.EOF -} - -func (t *fakeTUN) Write(b [][]byte, n int) (int, error) { - select { - case <-t.closechan: - return 0, ErrClosed - default: - } - return 1, nil -} - -// FakeTUNName is the name of the fake TUN device. -const FakeTUNName = "FakeTUN" - -func (t *fakeTUN) Flush() error { return nil } -func (t *fakeTUN) MTU() (int, error) { return 1500, nil } -func (t *fakeTUN) Name() (string, error) { return FakeTUNName, nil } -func (t *fakeTUN) Events() <-chan tun.Event { return t.evchan } -func (t *fakeTUN) BatchSize() int { return 1 } -func (t *fakeTUN) IsFakeTun() bool { return true } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstun + +import ( + "io" + "os" + + "github.com/tailscale/wireguard-go/tun" +) + +type fakeTUN struct { + evchan chan tun.Event + closechan chan struct{} +} + +// NewFake returns a tun.Device that does nothing. +func NewFake() tun.Device { + return &fakeTUN{ + evchan: make(chan tun.Event), + closechan: make(chan struct{}), + } +} + +func (t *fakeTUN) File() *os.File { + panic("fakeTUN.File() called, which makes no sense") +} + +func (t *fakeTUN) Close() error { + close(t.closechan) + close(t.evchan) + return nil +} + +func (t *fakeTUN) Read(out [][]byte, sizes []int, offset int) (int, error) { + <-t.closechan + return 0, io.EOF +} + +func (t *fakeTUN) Write(b [][]byte, n int) (int, error) { + select { + case <-t.closechan: + return 0, ErrClosed + default: + } + return 1, nil +} + +// FakeTUNName is the name of the fake TUN device. +const FakeTUNName = "FakeTUN" + +func (t *fakeTUN) Flush() error { return nil } +func (t *fakeTUN) MTU() (int, error) { return 1500, nil } +func (t *fakeTUN) Name() (string, error) { return FakeTUNName, nil } +func (t *fakeTUN) Events() <-chan tun.Event { return t.evchan } +func (t *fakeTUN) BatchSize() int { return 1 } +func (t *fakeTUN) IsFakeTun() bool { return true } diff --git a/net/tstun/ifstatus_noop.go b/net/tstun/ifstatus_noop.go index 4d453b72c..8cf569f98 100644 --- a/net/tstun/ifstatus_noop.go +++ b/net/tstun/ifstatus_noop.go @@ -1,18 +1,18 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows - -package tstun - -import ( - "time" - - "github.com/tailscale/wireguard-go/tun" - "tailscale.com/types/logger" -) - -// Dummy implementation that does nothing. -func waitInterfaceUp(iface tun.Device, timeout time.Duration, logf logger.Logf) error { - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows + +package tstun + +import ( + "time" + + "github.com/tailscale/wireguard-go/tun" + "tailscale.com/types/logger" +) + +// Dummy implementation that does nothing. +func waitInterfaceUp(iface tun.Device, timeout time.Duration, logf logger.Logf) error { + return nil +} diff --git a/net/tstun/ifstatus_windows.go b/net/tstun/ifstatus_windows.go index 6c6377bb4..fd9fc2112 100644 --- a/net/tstun/ifstatus_windows.go +++ b/net/tstun/ifstatus_windows.go @@ -1,109 +1,109 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstun - -import ( - "fmt" - "sync" - "time" - - "github.com/tailscale/wireguard-go/tun" - "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" - "tailscale.com/types/logger" -) - -// ifaceWatcher waits for an interface to be up. -type ifaceWatcher struct { - logf logger.Logf - luid winipcfg.LUID - - mu sync.Mutex // guards following - done bool - sig chan bool -} - -// callback is the callback we register with Windows to call when IP interface changes. -func (iw *ifaceWatcher) callback(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) { - // Probably should check only when MibParameterNotification, but just in case included MibAddInstance also. - if notificationType == winipcfg.MibParameterNotification || notificationType == winipcfg.MibAddInstance { - // Out of paranoia, start a goroutine to finish our work, to return to Windows out of this callback. - go iw.isUp() - } -} - -func (iw *ifaceWatcher) isUp() bool { - iw.mu.Lock() - defer iw.mu.Unlock() - - if iw.done { - // We already know that it's up - return true - } - - if iw.getOperStatus() != winipcfg.IfOperStatusUp { - return false - } - - iw.done = true - iw.sig <- true - return true -} - -func (iw *ifaceWatcher) getOperStatus() winipcfg.IfOperStatus { - ifc, err := iw.luid.Interface() - if err != nil { - iw.logf("iw.luid.Interface error: %v", err) - return 0 - } - return ifc.OperStatus -} - -func waitInterfaceUp(iface tun.Device, timeout time.Duration, logf logger.Logf) error { - iw := &ifaceWatcher{ - luid: winipcfg.LUID(iface.(*tun.NativeTun).LUID()), - logf: logger.WithPrefix(logf, "waitInterfaceUp: "), - } - - // Just in case check the status first - if iw.getOperStatus() == winipcfg.IfOperStatusUp { - iw.logf("TUN interface already up; no need to wait") - return nil - } - - iw.sig = make(chan bool, 1) - cb, err := winipcfg.RegisterInterfaceChangeCallback(iw.callback) - if err != nil { - iw.logf("RegisterInterfaceChangeCallback error: %v", err) - return err - } - defer cb.Unregister() - - t0 := time.Now() - expires := t0.Add(timeout) - ticker := time.NewTicker(10 * time.Second) - defer ticker.Stop() - - for { - iw.logf("waiting for TUN interface to come up...") - - select { - case <-iw.sig: - iw.logf("TUN interface is up after %v", time.Since(t0)) - return nil - case <-ticker.C: - } - - if iw.isUp() { - // Very unlikely to happen - either NotifyIpInterfaceChange doesn't work - // or it came up in the same moment as tick. Indicate this in the log message. - iw.logf("TUN interface is up after %v (on poll, without notification)", time.Since(t0)) - return nil - } - - if expires.Before(time.Now()) { - iw.logf("timeout waiting %v for TUN interface to come up", timeout) - return fmt.Errorf("timeout waiting for TUN interface to come up") - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstun + +import ( + "fmt" + "sync" + "time" + + "github.com/tailscale/wireguard-go/tun" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" + "tailscale.com/types/logger" +) + +// ifaceWatcher waits for an interface to be up. +type ifaceWatcher struct { + logf logger.Logf + luid winipcfg.LUID + + mu sync.Mutex // guards following + done bool + sig chan bool +} + +// callback is the callback we register with Windows to call when IP interface changes. +func (iw *ifaceWatcher) callback(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) { + // Probably should check only when MibParameterNotification, but just in case included MibAddInstance also. + if notificationType == winipcfg.MibParameterNotification || notificationType == winipcfg.MibAddInstance { + // Out of paranoia, start a goroutine to finish our work, to return to Windows out of this callback. + go iw.isUp() + } +} + +func (iw *ifaceWatcher) isUp() bool { + iw.mu.Lock() + defer iw.mu.Unlock() + + if iw.done { + // We already know that it's up + return true + } + + if iw.getOperStatus() != winipcfg.IfOperStatusUp { + return false + } + + iw.done = true + iw.sig <- true + return true +} + +func (iw *ifaceWatcher) getOperStatus() winipcfg.IfOperStatus { + ifc, err := iw.luid.Interface() + if err != nil { + iw.logf("iw.luid.Interface error: %v", err) + return 0 + } + return ifc.OperStatus +} + +func waitInterfaceUp(iface tun.Device, timeout time.Duration, logf logger.Logf) error { + iw := &ifaceWatcher{ + luid: winipcfg.LUID(iface.(*tun.NativeTun).LUID()), + logf: logger.WithPrefix(logf, "waitInterfaceUp: "), + } + + // Just in case check the status first + if iw.getOperStatus() == winipcfg.IfOperStatusUp { + iw.logf("TUN interface already up; no need to wait") + return nil + } + + iw.sig = make(chan bool, 1) + cb, err := winipcfg.RegisterInterfaceChangeCallback(iw.callback) + if err != nil { + iw.logf("RegisterInterfaceChangeCallback error: %v", err) + return err + } + defer cb.Unregister() + + t0 := time.Now() + expires := t0.Add(timeout) + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + for { + iw.logf("waiting for TUN interface to come up...") + + select { + case <-iw.sig: + iw.logf("TUN interface is up after %v", time.Since(t0)) + return nil + case <-ticker.C: + } + + if iw.isUp() { + // Very unlikely to happen - either NotifyIpInterfaceChange doesn't work + // or it came up in the same moment as tick. Indicate this in the log message. + iw.logf("TUN interface is up after %v (on poll, without notification)", time.Since(t0)) + return nil + } + + if expires.Before(time.Now()) { + iw.logf("timeout waiting %v for TUN interface to come up", timeout) + return fmt.Errorf("timeout waiting for TUN interface to come up") + } + } +} diff --git a/net/tstun/linkattrs_linux.go b/net/tstun/linkattrs_linux.go index 7f5461109..681e79269 100644 --- a/net/tstun/linkattrs_linux.go +++ b/net/tstun/linkattrs_linux.go @@ -1,63 +1,63 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstun - -import ( - "github.com/mdlayher/genetlink" - "github.com/mdlayher/netlink" - "github.com/tailscale/wireguard-go/tun" - "golang.org/x/sys/unix" -) - -// setLinkSpeed sets the advertised link speed of the TUN interface. -func setLinkSpeed(iface tun.Device, mbps int) error { - name, err := iface.Name() - if err != nil { - return err - } - - conn, err := genetlink.Dial(&netlink.Config{Strict: true}) - if err != nil { - return err - } - - defer conn.Close() - - f, err := conn.GetFamily(unix.ETHTOOL_GENL_NAME) - if err != nil { - return err - } - - ae := netlink.NewAttributeEncoder() - ae.Nested(unix.ETHTOOL_A_LINKMODES_HEADER, func(nae *netlink.AttributeEncoder) error { - nae.String(unix.ETHTOOL_A_HEADER_DEV_NAME, name) - return nil - }) - ae.Uint32(unix.ETHTOOL_A_LINKMODES_SPEED, uint32(mbps)) - - b, err := ae.Encode() - if err != nil { - return err - } - - _, err = conn.Execute( - genetlink.Message{ - Header: genetlink.Header{ - Command: unix.ETHTOOL_MSG_LINKMODES_SET, - Version: unix.ETHTOOL_GENL_VERSION, - }, - Data: b, - }, - f.ID, - netlink.Request|netlink.Acknowledge, - ) - return err -} - -// setLinkAttrs sets up link attributes that can be queried by external tools. -// Its failure is non-fatal to interface bringup. -func setLinkAttrs(iface tun.Device) error { - // By default the link speed is 10Mbps, which is easily exceeded and causes monitoring tools to complain (#3933). - return setLinkSpeed(iface, unix.SPEED_UNKNOWN) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstun + +import ( + "github.com/mdlayher/genetlink" + "github.com/mdlayher/netlink" + "github.com/tailscale/wireguard-go/tun" + "golang.org/x/sys/unix" +) + +// setLinkSpeed sets the advertised link speed of the TUN interface. +func setLinkSpeed(iface tun.Device, mbps int) error { + name, err := iface.Name() + if err != nil { + return err + } + + conn, err := genetlink.Dial(&netlink.Config{Strict: true}) + if err != nil { + return err + } + + defer conn.Close() + + f, err := conn.GetFamily(unix.ETHTOOL_GENL_NAME) + if err != nil { + return err + } + + ae := netlink.NewAttributeEncoder() + ae.Nested(unix.ETHTOOL_A_LINKMODES_HEADER, func(nae *netlink.AttributeEncoder) error { + nae.String(unix.ETHTOOL_A_HEADER_DEV_NAME, name) + return nil + }) + ae.Uint32(unix.ETHTOOL_A_LINKMODES_SPEED, uint32(mbps)) + + b, err := ae.Encode() + if err != nil { + return err + } + + _, err = conn.Execute( + genetlink.Message{ + Header: genetlink.Header{ + Command: unix.ETHTOOL_MSG_LINKMODES_SET, + Version: unix.ETHTOOL_GENL_VERSION, + }, + Data: b, + }, + f.ID, + netlink.Request|netlink.Acknowledge, + ) + return err +} + +// setLinkAttrs sets up link attributes that can be queried by external tools. +// Its failure is non-fatal to interface bringup. +func setLinkAttrs(iface tun.Device) error { + // By default the link speed is 10Mbps, which is easily exceeded and causes monitoring tools to complain (#3933). + return setLinkSpeed(iface, unix.SPEED_UNKNOWN) +} diff --git a/net/tstun/linkattrs_notlinux.go b/net/tstun/linkattrs_notlinux.go index 45dd000b3..7a7b40fc2 100644 --- a/net/tstun/linkattrs_notlinux.go +++ b/net/tstun/linkattrs_notlinux.go @@ -1,12 +1,12 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux - -package tstun - -import "github.com/tailscale/wireguard-go/tun" - -func setLinkAttrs(iface tun.Device) error { - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux + +package tstun + +import "github.com/tailscale/wireguard-go/tun" + +func setLinkAttrs(iface tun.Device) error { + return nil +} diff --git a/net/tstun/mtu.go b/net/tstun/mtu.go index b72a19bde..004529c20 100644 --- a/net/tstun/mtu.go +++ b/net/tstun/mtu.go @@ -1,161 +1,161 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstun - -import ( - "tailscale.com/envknob" -) - -// The MTU (Maximum Transmission Unit) of a network interface is the largest -// packet that can be sent or received through that interface, including all -// headers above the link layer (e.g. IP headers, UDP headers, Wireguard -// headers, etc.). We have to think about several different values of MTU: -// -// Wire MTU: The MTU of an interface underneath the tailscale TUN, e.g. an -// Ethernet network card will default to a 1500 byte MTU. The user may change -// this MTU at any time. -// -// TUN MTU: The current MTU of the tailscale TUN. This MTU is adjusted downward -// to make room for the wireguard/tailscale headers. For example, if the -// underlying network interface's MTU is 1500 bytes, the maximum size of a -// packet entering the tailscale TUN is 1420 bytes. The user may change this MTU -// at any time via the OS's tools (ifconfig, ip, etc.). -// -// User configured initial MTU: The MTU the tailscale TUN should be created -// with, set by the user via TS_DEBUG_MTU. It should be adjusted down from the -// underlying interface MTU by 80 bytes to make room for the wireguard -// headers. This envknob is mostly for debugging. This value is used once at TUN -// creation and ignored thereafter. -// -// User configured current MTU: The MTU set via the OS's tools (ifconfig, ip, -// etc.). This MTU can change at any time. Setting the MTU this way goes through -// the MTU() method of tailscale's TUN wrapper. -// -// Maximum probed MTU: This is the largest MTU size that we send probe packets -// for. -// -// Safe MTU: If the tailscale TUN MTU is set to this value, almost all packets -// will get to their destination. Tailscale defaults to this MTU in the absence -// of path MTU probe information or user MTU configuration. We may occasionally -// find a path that needs a smaller MTU but it is very rare. -// -// Peer MTU: This is the path MTU to a peer's current best endpoint. It defaults -// to the Safe MTU unless we have path MTU probe results that tell us otherwise. -// -// Initial MTU: This is the MTU tailscaled creates the TUN with. In order of -// priority, it is: -// -// 1. If set, the value of TS_DEBUG_MTU clamped to a maximum of 65536 -// 2. If TS_DEBUG_ENABLE_PMTUD is set, the maximum size MTU we probe, minus wg -// overhead -// 3. If TS_DEBUG_ENABLE_PMTUD is not set, the Safe MTU -// -// Current MTU: This the MTU of the tailscale TUN at any given moment -// after TUN creation. In order of priority, it is: -// -// 1. The MTU set by the user via the OS, if it has ever been set -// 2. If TS_DEBUG_ENABLE_PMTUD is set, the maximum size MTU we probe, minus wg -// overhead -// 4. If TS_DEBUG_ENABLE_PMTUD is not set, the Safe MTU - -// TUNMTU is the MTU for the tailscale TUN. -type TUNMTU uint32 - -// WireMTU is the MTU for the underlying network devices. -type WireMTU uint32 - -const ( - // maxTUNMTU is the largest MTU we will consider for the Tailscale - // TUN. This is inherited from wireguard-go and can be surprisingly - // small; on Windows it is currently 2048 - 32 bytes and iOS it is 1700 - // - 32 bytes. - // TODO(val,raggi): On Windows this seems to derive from RIO driver - // constraints in Wireguard but we don't use RIO so could probably make - // this bigger. - maxTUNMTU TUNMTU = TUNMTU(MaxPacketSize) - // safeTUNMTU is the default "safe" MTU for the Tailscale TUN that we - // use in the absence of other information such as path MTU probes. - safeTUNMTU TUNMTU = 1280 -) - -// WireMTUsToProbe is a list of the on-the-wire MTUs we want to probe. Each time -// magicsock discovery begins, it will send a set of pings, one of each size -// listed below. -var WireMTUsToProbe = []WireMTU{ - WireMTU(safeTUNMTU), // Tailscale over Tailscale :) - TUNToWireMTU(safeTUNMTU), // Smallest MTU allowed for IPv6, current default - 1400, // Most common MTU minus a few bytes for tunnels - 1500, // Most common MTU - 8000, // Should fit inside all jumbo frame sizes - 9000, // Most jumbo frames are this size or larger -} - -// wgHeaderLen is the length of all the headers Wireguard adds to a packet -// in the worst case (IPv6). This constant is for use when we can't or -// shouldn't use information about the IP version of a specific packet -// (e.g., calculating the MTU for the Tailscale interface. -// -// A Wireguard header includes: -// -// - 20-byte IPv4 header or 40-byte IPv6 header -// - 8-byte UDP header -// - 4-byte type -// - 4-byte key index -// - 8-byte nonce -// - 16-byte authentication tag -const wgHeaderLen = 40 + 8 + 4 + 4 + 8 + 16 - -// TUNToWireMTU takes the MTU that the Tailscale TUN presents to the user and -// returns the on-the-wire MTU necessary to transmit the largest packet that -// will fit through the TUN, given that we have to add wireguard headers. -func TUNToWireMTU(t TUNMTU) WireMTU { - return WireMTU(t + wgHeaderLen) -} - -// WireToTUNMTU takes the MTU of an underlying network device and returns the -// largest possible MTU for a Tailscale TUN operating on top of that device, -// given that we have to add wireguard headers. -func WireToTUNMTU(w WireMTU) TUNMTU { - if w < wgHeaderLen { - return 0 - } - return TUNMTU(w - wgHeaderLen) -} - -// DefaultTUNMTU returns the MTU we use to set the Tailscale TUN -// MTU. It is also the path MTU that we default to if we have no -// information about the path to a peer. -// -// 1. If set, the value of TS_DEBUG_MTU clamped to a maximum of MaxTUNMTU -// 2. If TS_DEBUG_ENABLE_PMTUD is set, the maximum size MTU we probe, minus wg overhead -// 3. If TS_DEBUG_ENABLE_PMTUD is not set, the Safe MTU -func DefaultTUNMTU() TUNMTU { - if m, ok := envknob.LookupUintSized("TS_DEBUG_MTU", 10, 32); ok { - return min(TUNMTU(m), maxTUNMTU) - } - - debugPMTUD, _ := envknob.LookupBool("TS_DEBUG_ENABLE_PMTUD") - if debugPMTUD { - // TODO: While we are just probing MTU but not generating PTB, - // this has to continue to return the safe MTU. When we add the - // code to generate PTB, this will be: - // - // return WireToTUNMTU(maxProbedWireMTU) - return safeTUNMTU - } - - return safeTUNMTU -} - -// SafeWireMTU returns the wire MTU that is safe to use if we have no -// information about the path MTU to this peer. -func SafeWireMTU() WireMTU { - return TUNToWireMTU(safeTUNMTU) -} - -// DefaultWireMTU returns the default TUN MTU, adjusted for wireguard -// overhead. -func DefaultWireMTU() WireMTU { - return TUNToWireMTU(DefaultTUNMTU()) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstun + +import ( + "tailscale.com/envknob" +) + +// The MTU (Maximum Transmission Unit) of a network interface is the largest +// packet that can be sent or received through that interface, including all +// headers above the link layer (e.g. IP headers, UDP headers, Wireguard +// headers, etc.). We have to think about several different values of MTU: +// +// Wire MTU: The MTU of an interface underneath the tailscale TUN, e.g. an +// Ethernet network card will default to a 1500 byte MTU. The user may change +// this MTU at any time. +// +// TUN MTU: The current MTU of the tailscale TUN. This MTU is adjusted downward +// to make room for the wireguard/tailscale headers. For example, if the +// underlying network interface's MTU is 1500 bytes, the maximum size of a +// packet entering the tailscale TUN is 1420 bytes. The user may change this MTU +// at any time via the OS's tools (ifconfig, ip, etc.). +// +// User configured initial MTU: The MTU the tailscale TUN should be created +// with, set by the user via TS_DEBUG_MTU. It should be adjusted down from the +// underlying interface MTU by 80 bytes to make room for the wireguard +// headers. This envknob is mostly for debugging. This value is used once at TUN +// creation and ignored thereafter. +// +// User configured current MTU: The MTU set via the OS's tools (ifconfig, ip, +// etc.). This MTU can change at any time. Setting the MTU this way goes through +// the MTU() method of tailscale's TUN wrapper. +// +// Maximum probed MTU: This is the largest MTU size that we send probe packets +// for. +// +// Safe MTU: If the tailscale TUN MTU is set to this value, almost all packets +// will get to their destination. Tailscale defaults to this MTU in the absence +// of path MTU probe information or user MTU configuration. We may occasionally +// find a path that needs a smaller MTU but it is very rare. +// +// Peer MTU: This is the path MTU to a peer's current best endpoint. It defaults +// to the Safe MTU unless we have path MTU probe results that tell us otherwise. +// +// Initial MTU: This is the MTU tailscaled creates the TUN with. In order of +// priority, it is: +// +// 1. If set, the value of TS_DEBUG_MTU clamped to a maximum of 65536 +// 2. If TS_DEBUG_ENABLE_PMTUD is set, the maximum size MTU we probe, minus wg +// overhead +// 3. If TS_DEBUG_ENABLE_PMTUD is not set, the Safe MTU +// +// Current MTU: This the MTU of the tailscale TUN at any given moment +// after TUN creation. In order of priority, it is: +// +// 1. The MTU set by the user via the OS, if it has ever been set +// 2. If TS_DEBUG_ENABLE_PMTUD is set, the maximum size MTU we probe, minus wg +// overhead +// 4. If TS_DEBUG_ENABLE_PMTUD is not set, the Safe MTU + +// TUNMTU is the MTU for the tailscale TUN. +type TUNMTU uint32 + +// WireMTU is the MTU for the underlying network devices. +type WireMTU uint32 + +const ( + // maxTUNMTU is the largest MTU we will consider for the Tailscale + // TUN. This is inherited from wireguard-go and can be surprisingly + // small; on Windows it is currently 2048 - 32 bytes and iOS it is 1700 + // - 32 bytes. + // TODO(val,raggi): On Windows this seems to derive from RIO driver + // constraints in Wireguard but we don't use RIO so could probably make + // this bigger. + maxTUNMTU TUNMTU = TUNMTU(MaxPacketSize) + // safeTUNMTU is the default "safe" MTU for the Tailscale TUN that we + // use in the absence of other information such as path MTU probes. + safeTUNMTU TUNMTU = 1280 +) + +// WireMTUsToProbe is a list of the on-the-wire MTUs we want to probe. Each time +// magicsock discovery begins, it will send a set of pings, one of each size +// listed below. +var WireMTUsToProbe = []WireMTU{ + WireMTU(safeTUNMTU), // Tailscale over Tailscale :) + TUNToWireMTU(safeTUNMTU), // Smallest MTU allowed for IPv6, current default + 1400, // Most common MTU minus a few bytes for tunnels + 1500, // Most common MTU + 8000, // Should fit inside all jumbo frame sizes + 9000, // Most jumbo frames are this size or larger +} + +// wgHeaderLen is the length of all the headers Wireguard adds to a packet +// in the worst case (IPv6). This constant is for use when we can't or +// shouldn't use information about the IP version of a specific packet +// (e.g., calculating the MTU for the Tailscale interface. +// +// A Wireguard header includes: +// +// - 20-byte IPv4 header or 40-byte IPv6 header +// - 8-byte UDP header +// - 4-byte type +// - 4-byte key index +// - 8-byte nonce +// - 16-byte authentication tag +const wgHeaderLen = 40 + 8 + 4 + 4 + 8 + 16 + +// TUNToWireMTU takes the MTU that the Tailscale TUN presents to the user and +// returns the on-the-wire MTU necessary to transmit the largest packet that +// will fit through the TUN, given that we have to add wireguard headers. +func TUNToWireMTU(t TUNMTU) WireMTU { + return WireMTU(t + wgHeaderLen) +} + +// WireToTUNMTU takes the MTU of an underlying network device and returns the +// largest possible MTU for a Tailscale TUN operating on top of that device, +// given that we have to add wireguard headers. +func WireToTUNMTU(w WireMTU) TUNMTU { + if w < wgHeaderLen { + return 0 + } + return TUNMTU(w - wgHeaderLen) +} + +// DefaultTUNMTU returns the MTU we use to set the Tailscale TUN +// MTU. It is also the path MTU that we default to if we have no +// information about the path to a peer. +// +// 1. If set, the value of TS_DEBUG_MTU clamped to a maximum of MaxTUNMTU +// 2. If TS_DEBUG_ENABLE_PMTUD is set, the maximum size MTU we probe, minus wg overhead +// 3. If TS_DEBUG_ENABLE_PMTUD is not set, the Safe MTU +func DefaultTUNMTU() TUNMTU { + if m, ok := envknob.LookupUintSized("TS_DEBUG_MTU", 10, 32); ok { + return min(TUNMTU(m), maxTUNMTU) + } + + debugPMTUD, _ := envknob.LookupBool("TS_DEBUG_ENABLE_PMTUD") + if debugPMTUD { + // TODO: While we are just probing MTU but not generating PTB, + // this has to continue to return the safe MTU. When we add the + // code to generate PTB, this will be: + // + // return WireToTUNMTU(maxProbedWireMTU) + return safeTUNMTU + } + + return safeTUNMTU +} + +// SafeWireMTU returns the wire MTU that is safe to use if we have no +// information about the path MTU to this peer. +func SafeWireMTU() WireMTU { + return TUNToWireMTU(safeTUNMTU) +} + +// DefaultWireMTU returns the default TUN MTU, adjusted for wireguard +// overhead. +func DefaultWireMTU() WireMTU { + return TUNToWireMTU(DefaultTUNMTU()) +} diff --git a/net/tstun/mtu_test.go b/net/tstun/mtu_test.go index fc5274ae1..8d165bfd3 100644 --- a/net/tstun/mtu_test.go +++ b/net/tstun/mtu_test.go @@ -1,99 +1,99 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause -package tstun - -import ( - "os" - "strconv" - "testing" -) - -// Test the default MTU in the presence of various envknobs. -func TestDefaultTunMTU(t *testing.T) { - // Save and restore the envknobs we will be changing. - - // TS_DEBUG_MTU sets the MTU to a specific value. - defer os.Setenv("TS_DEBUG_MTU", os.Getenv("TS_DEBUG_MTU")) - os.Setenv("TS_DEBUG_MTU", "") - - // TS_DEBUG_ENABLE_PMTUD enables path MTU discovery. - defer os.Setenv("TS_DEBUG_ENABLE_PMTUD", os.Getenv("TS_DEBUG_ENABLE_PMTUD")) - os.Setenv("TS_DEBUG_ENABLE_PMTUD", "") - - // With no MTU envknobs set, we should get the conservative MTU. - if DefaultTUNMTU() != safeTUNMTU { - t.Errorf("default TUN MTU = %d, want %d", DefaultTUNMTU(), safeTUNMTU) - } - - // If set, TS_DEBUG_MTU should set the MTU. - mtu := maxTUNMTU - 1 - os.Setenv("TS_DEBUG_MTU", strconv.Itoa(int(mtu))) - if DefaultTUNMTU() != mtu { - t.Errorf("default TUN MTU = %d, want %d, TS_DEBUG_MTU ignored", DefaultTUNMTU(), mtu) - } - - // MTU should be clamped to maxTunMTU. - mtu = maxTUNMTU + 1 - os.Setenv("TS_DEBUG_MTU", strconv.Itoa(int(mtu))) - if DefaultTUNMTU() != maxTUNMTU { - t.Errorf("default TUN MTU = %d, want %d, clamping failed", DefaultTUNMTU(), maxTUNMTU) - } - - // If PMTUD is enabled, the MTU should default to the safe MTU, but only - // if the user hasn't requested a specific MTU. - // - // TODO: When PMTUD is generating PTB responses, this will become the - // largest MTU we probe. - os.Setenv("TS_DEBUG_MTU", "") - os.Setenv("TS_DEBUG_ENABLE_PMTUD", "true") - if DefaultTUNMTU() != safeTUNMTU { - t.Errorf("default TUN MTU = %d, want %d", DefaultTUNMTU(), safeTUNMTU) - } - // TS_DEBUG_MTU should take precedence over TS_DEBUG_ENABLE_PMTUD. - mtu = WireToTUNMTU(MaxPacketSize - 1) - os.Setenv("TS_DEBUG_MTU", strconv.Itoa(int(mtu))) - if DefaultTUNMTU() != mtu { - t.Errorf("default TUN MTU = %d, want %d", DefaultTUNMTU(), mtu) - } -} - -// Test the conversion of wire MTU to/from Tailscale TUN MTU corner cases. -func TestMTUConversion(t *testing.T) { - tests := []struct { - w WireMTU - t TUNMTU - }{ - {w: 0, t: 0}, - {w: wgHeaderLen - 1, t: 0}, - {w: wgHeaderLen, t: 0}, - {w: wgHeaderLen + 1, t: 1}, - {w: 1360, t: 1280}, - {w: 1500, t: 1420}, - {w: 9000, t: 8920}, - } - - for _, tt := range tests { - m := WireToTUNMTU(tt.w) - if m != tt.t { - t.Errorf("conversion of wire MTU %v to TUN MTU = %v, want %v", tt.w, m, tt.t) - } - } - - tests2 := []struct { - t TUNMTU - w WireMTU - }{ - {t: 0, w: wgHeaderLen}, - {t: 1, w: wgHeaderLen + 1}, - {t: 1280, w: 1360}, - {t: 1420, w: 1500}, - {t: 8920, w: 9000}, - } - - for _, tt := range tests2 { - m := TUNToWireMTU(tt.t) - if m != tt.w { - t.Errorf("conversion of TUN MTU %v to wire MTU = %v, want %v", tt.t, m, tt.w) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause +package tstun + +import ( + "os" + "strconv" + "testing" +) + +// Test the default MTU in the presence of various envknobs. +func TestDefaultTunMTU(t *testing.T) { + // Save and restore the envknobs we will be changing. + + // TS_DEBUG_MTU sets the MTU to a specific value. + defer os.Setenv("TS_DEBUG_MTU", os.Getenv("TS_DEBUG_MTU")) + os.Setenv("TS_DEBUG_MTU", "") + + // TS_DEBUG_ENABLE_PMTUD enables path MTU discovery. + defer os.Setenv("TS_DEBUG_ENABLE_PMTUD", os.Getenv("TS_DEBUG_ENABLE_PMTUD")) + os.Setenv("TS_DEBUG_ENABLE_PMTUD", "") + + // With no MTU envknobs set, we should get the conservative MTU. + if DefaultTUNMTU() != safeTUNMTU { + t.Errorf("default TUN MTU = %d, want %d", DefaultTUNMTU(), safeTUNMTU) + } + + // If set, TS_DEBUG_MTU should set the MTU. + mtu := maxTUNMTU - 1 + os.Setenv("TS_DEBUG_MTU", strconv.Itoa(int(mtu))) + if DefaultTUNMTU() != mtu { + t.Errorf("default TUN MTU = %d, want %d, TS_DEBUG_MTU ignored", DefaultTUNMTU(), mtu) + } + + // MTU should be clamped to maxTunMTU. + mtu = maxTUNMTU + 1 + os.Setenv("TS_DEBUG_MTU", strconv.Itoa(int(mtu))) + if DefaultTUNMTU() != maxTUNMTU { + t.Errorf("default TUN MTU = %d, want %d, clamping failed", DefaultTUNMTU(), maxTUNMTU) + } + + // If PMTUD is enabled, the MTU should default to the safe MTU, but only + // if the user hasn't requested a specific MTU. + // + // TODO: When PMTUD is generating PTB responses, this will become the + // largest MTU we probe. + os.Setenv("TS_DEBUG_MTU", "") + os.Setenv("TS_DEBUG_ENABLE_PMTUD", "true") + if DefaultTUNMTU() != safeTUNMTU { + t.Errorf("default TUN MTU = %d, want %d", DefaultTUNMTU(), safeTUNMTU) + } + // TS_DEBUG_MTU should take precedence over TS_DEBUG_ENABLE_PMTUD. + mtu = WireToTUNMTU(MaxPacketSize - 1) + os.Setenv("TS_DEBUG_MTU", strconv.Itoa(int(mtu))) + if DefaultTUNMTU() != mtu { + t.Errorf("default TUN MTU = %d, want %d", DefaultTUNMTU(), mtu) + } +} + +// Test the conversion of wire MTU to/from Tailscale TUN MTU corner cases. +func TestMTUConversion(t *testing.T) { + tests := []struct { + w WireMTU + t TUNMTU + }{ + {w: 0, t: 0}, + {w: wgHeaderLen - 1, t: 0}, + {w: wgHeaderLen, t: 0}, + {w: wgHeaderLen + 1, t: 1}, + {w: 1360, t: 1280}, + {w: 1500, t: 1420}, + {w: 9000, t: 8920}, + } + + for _, tt := range tests { + m := WireToTUNMTU(tt.w) + if m != tt.t { + t.Errorf("conversion of wire MTU %v to TUN MTU = %v, want %v", tt.w, m, tt.t) + } + } + + tests2 := []struct { + t TUNMTU + w WireMTU + }{ + {t: 0, w: wgHeaderLen}, + {t: 1, w: wgHeaderLen + 1}, + {t: 1280, w: 1360}, + {t: 1420, w: 1500}, + {t: 8920, w: 9000}, + } + + for _, tt := range tests2 { + m := TUNToWireMTU(tt.t) + if m != tt.w { + t.Errorf("conversion of TUN MTU %v to wire MTU = %v, want %v", tt.t, m, tt.w) + } + } +} diff --git a/net/tstun/tun_linux.go b/net/tstun/tun_linux.go index e08f12bc1..9600ceb77 100644 --- a/net/tstun/tun_linux.go +++ b/net/tstun/tun_linux.go @@ -1,103 +1,103 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstun - -import ( - "bytes" - "errors" - "os" - "os/exec" - "strings" - "syscall" - - "tailscale.com/types/logger" - "tailscale.com/version/distro" -) - -func init() { - tunDiagnoseFailure = diagnoseLinuxTUNFailure -} - -func diagnoseLinuxTUNFailure(tunName string, logf logger.Logf, createErr error) { - if errors.Is(createErr, syscall.EBUSY) { - logf("TUN device %s is busy; another process probably still has it open (from old version of Tailscale that had a bug)", tunName) - logf("To fix, kill the process that has it open. Find with:\n\n$ sudo lsof -n /dev/net/tun\n\n") - logf("... and then kill those PID(s)") - return - } - - var un syscall.Utsname - err := syscall.Uname(&un) - if err != nil { - logf("no TUN, and failed to look up kernel version: %v", err) - return - } - kernel := utsReleaseField(&un) - logf("Linux kernel version: %s", kernel) - - modprobeOut, err := exec.Command("/sbin/modprobe", "tun").CombinedOutput() - if err == nil { - logf("'modprobe tun' successful") - // Either tun is currently loaded, or it's statically - // compiled into the kernel (which modprobe checks - // with /lib/modules/$(uname -r)/modules.builtin) - // - // So if there's a problem at this point, it's - // probably because /dev/net/tun doesn't exist. - const dev = "/dev/net/tun" - if fi, err := os.Stat(dev); err != nil { - logf("tun module loaded in kernel, but %s does not exist", dev) - } else { - logf("%s: %v", dev, fi.Mode()) - } - - // We failed to find why it failed. Just let our - // caller report the error it got from wireguard-go. - return - } - logf("is CONFIG_TUN enabled in your kernel? `modprobe tun` failed with: %s", modprobeOut) - - switch distro.Get() { - case distro.Debian: - dpkgOut, err := exec.Command("dpkg", "-S", "kernel/drivers/net/tun.ko").CombinedOutput() - if len(bytes.TrimSpace(dpkgOut)) == 0 || err != nil { - logf("tun module not loaded nor found on disk") - return - } - if !bytes.Contains(dpkgOut, []byte(kernel)) { - logf("kernel/drivers/net/tun.ko found on disk, but not for current kernel; are you in middle of a system update and haven't rebooted? found: %s", dpkgOut) - } - case distro.Arch: - findOut, err := exec.Command("find", "/lib/modules/", "-path", "*/net/tun.ko*").CombinedOutput() - if len(bytes.TrimSpace(findOut)) == 0 || err != nil { - logf("tun module not loaded nor found on disk") - return - } - if !bytes.Contains(findOut, []byte(kernel)) { - logf("kernel/drivers/net/tun.ko found on disk, but not for current kernel; are you in middle of a system update and haven't rebooted? found: %s", findOut) - } - case distro.OpenWrt: - out, err := exec.Command("opkg", "list-installed").CombinedOutput() - if err != nil { - logf("error querying OpenWrt installed packages: %s", out) - return - } - for _, pkg := range []string{"kmod-tun", "ca-bundle"} { - if !bytes.Contains(out, []byte(pkg+" - ")) { - logf("Missing required package %s; run: opkg install %s", pkg, pkg) - } - } - } -} - -func utsReleaseField(u *syscall.Utsname) string { - var sb strings.Builder - for _, v := range u.Release { - if v == 0 { - break - } - sb.WriteByte(byte(v)) - } - return strings.TrimSpace(sb.String()) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstun + +import ( + "bytes" + "errors" + "os" + "os/exec" + "strings" + "syscall" + + "tailscale.com/types/logger" + "tailscale.com/version/distro" +) + +func init() { + tunDiagnoseFailure = diagnoseLinuxTUNFailure +} + +func diagnoseLinuxTUNFailure(tunName string, logf logger.Logf, createErr error) { + if errors.Is(createErr, syscall.EBUSY) { + logf("TUN device %s is busy; another process probably still has it open (from old version of Tailscale that had a bug)", tunName) + logf("To fix, kill the process that has it open. Find with:\n\n$ sudo lsof -n /dev/net/tun\n\n") + logf("... and then kill those PID(s)") + return + } + + var un syscall.Utsname + err := syscall.Uname(&un) + if err != nil { + logf("no TUN, and failed to look up kernel version: %v", err) + return + } + kernel := utsReleaseField(&un) + logf("Linux kernel version: %s", kernel) + + modprobeOut, err := exec.Command("/sbin/modprobe", "tun").CombinedOutput() + if err == nil { + logf("'modprobe tun' successful") + // Either tun is currently loaded, or it's statically + // compiled into the kernel (which modprobe checks + // with /lib/modules/$(uname -r)/modules.builtin) + // + // So if there's a problem at this point, it's + // probably because /dev/net/tun doesn't exist. + const dev = "/dev/net/tun" + if fi, err := os.Stat(dev); err != nil { + logf("tun module loaded in kernel, but %s does not exist", dev) + } else { + logf("%s: %v", dev, fi.Mode()) + } + + // We failed to find why it failed. Just let our + // caller report the error it got from wireguard-go. + return + } + logf("is CONFIG_TUN enabled in your kernel? `modprobe tun` failed with: %s", modprobeOut) + + switch distro.Get() { + case distro.Debian: + dpkgOut, err := exec.Command("dpkg", "-S", "kernel/drivers/net/tun.ko").CombinedOutput() + if len(bytes.TrimSpace(dpkgOut)) == 0 || err != nil { + logf("tun module not loaded nor found on disk") + return + } + if !bytes.Contains(dpkgOut, []byte(kernel)) { + logf("kernel/drivers/net/tun.ko found on disk, but not for current kernel; are you in middle of a system update and haven't rebooted? found: %s", dpkgOut) + } + case distro.Arch: + findOut, err := exec.Command("find", "/lib/modules/", "-path", "*/net/tun.ko*").CombinedOutput() + if len(bytes.TrimSpace(findOut)) == 0 || err != nil { + logf("tun module not loaded nor found on disk") + return + } + if !bytes.Contains(findOut, []byte(kernel)) { + logf("kernel/drivers/net/tun.ko found on disk, but not for current kernel; are you in middle of a system update and haven't rebooted? found: %s", findOut) + } + case distro.OpenWrt: + out, err := exec.Command("opkg", "list-installed").CombinedOutput() + if err != nil { + logf("error querying OpenWrt installed packages: %s", out) + return + } + for _, pkg := range []string{"kmod-tun", "ca-bundle"} { + if !bytes.Contains(out, []byte(pkg+" - ")) { + logf("Missing required package %s; run: opkg install %s", pkg, pkg) + } + } + } +} + +func utsReleaseField(u *syscall.Utsname) string { + var sb strings.Builder + for _, v := range u.Release { + if v == 0 { + break + } + sb.WriteByte(byte(v)) + } + return strings.TrimSpace(sb.String()) +} diff --git a/net/tstun/tun_macos.go b/net/tstun/tun_macos.go index f71494f0b..3506f05b1 100644 --- a/net/tstun/tun_macos.go +++ b/net/tstun/tun_macos.go @@ -1,25 +1,25 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build darwin && !ios - -package tstun - -import ( - "os" - - "tailscale.com/types/logger" -) - -func init() { - tunDiagnoseFailure = diagnoseDarwinTUNFailure -} - -func diagnoseDarwinTUNFailure(tunName string, logf logger.Logf, err error) { - if os.Getuid() != 0 { - logf("failed to create TUN device as non-root user; use 'sudo tailscaled', or run under launchd with 'sudo tailscaled install-system-daemon'") - } - if tunName != "utun" { - logf("failed to create TUN device %q; try using tun device \"utun\" instead for automatic selection", tunName) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build darwin && !ios + +package tstun + +import ( + "os" + + "tailscale.com/types/logger" +) + +func init() { + tunDiagnoseFailure = diagnoseDarwinTUNFailure +} + +func diagnoseDarwinTUNFailure(tunName string, logf logger.Logf, err error) { + if os.Getuid() != 0 { + logf("failed to create TUN device as non-root user; use 'sudo tailscaled', or run under launchd with 'sudo tailscaled install-system-daemon'") + } + if tunName != "utun" { + logf("failed to create TUN device %q; try using tun device \"utun\" instead for automatic selection", tunName) + } +} diff --git a/net/tstun/tun_notwindows.go b/net/tstun/tun_notwindows.go index 60f1c62ba..087fcd4ee 100644 --- a/net/tstun/tun_notwindows.go +++ b/net/tstun/tun_notwindows.go @@ -1,12 +1,12 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows - -package tstun - -import "github.com/tailscale/wireguard-go/tun" - -func interfaceName(dev tun.Device) (string, error) { - return dev.Name() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows + +package tstun + +import "github.com/tailscale/wireguard-go/tun" + +func interfaceName(dev tun.Device) (string, error) { + return dev.Name() +} diff --git a/packages/deb/deb.go b/packages/deb/deb.go index 1be7f9652..30e3f2b4d 100644 --- a/packages/deb/deb.go +++ b/packages/deb/deb.go @@ -1,182 +1,182 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package deb extracts metadata from Debian packages. -package deb - -import ( - "archive/tar" - "bufio" - "bytes" - "compress/gzip" - "crypto/md5" - "crypto/sha1" - "crypto/sha256" - "errors" - "fmt" - "io" - "os" - "path/filepath" - "strconv" - "strings" -) - -// Info is the Debian package metadata needed to integrate the package -// into a repository. -type Info struct { - // Version is the version of the package, as reported by dpkg. - Version string - // Arch is the Debian CPU architecture the package is for. - Arch string - // Control is the entire contents of the package's control file, - // with leading and trailing whitespace removed. - Control []byte - // MD5 is the MD5 hash of the package file. - MD5 []byte - // SHA1 is the SHA1 hash of the package file. - SHA1 []byte - // SHA256 is the SHA256 hash of the package file. - SHA256 []byte -} - -// ReadFile returns Debian package metadata from the .deb file at path. -func ReadFile(path string) (*Info, error) { - f, err := os.Open(path) - if err != nil { - return nil, err - } - return Read(f) -} - -// Read returns Debian package metadata from the .deb file in r. -func Read(r io.Reader) (*Info, error) { - b := bufio.NewReader(r) - - m5, s1, s256 := md5.New(), sha1.New(), sha256.New() - summers := io.MultiWriter(m5, s1, s256) - r = io.TeeReader(b, summers) - - t, err := findControlTar(r) - if err != nil { - return nil, fmt.Errorf("searching for control.tar.gz: %w", err) - } - - control, err := findControlFile(t) - if err != nil { - return nil, fmt.Errorf("searching for control file in control.tar.gz: %w", err) - } - - arch, version, err := findArchAndVersion(control) - if err != nil { - return nil, fmt.Errorf("extracting version and architecture from control file: %w", err) - } - - // Exhaust the remainder of r, so that the summers see the entire file. - if _, err := io.Copy(io.Discard, r); err != nil { - return nil, fmt.Errorf("hashing file: %w", err) - } - - return &Info{ - Version: version, - Arch: arch, - Control: control, - MD5: m5.Sum(nil), - SHA1: s1.Sum(nil), - SHA256: s256.Sum(nil), - }, nil -} - -// findControlTar reads r as an `ar` archive, finds a tarball named -// `control.tar.gz` within, and returns a reader for that file. -func findControlTar(r io.Reader) (tarReader io.Reader, err error) { - var magic [8]byte - if _, err := io.ReadFull(r, magic[:]); err != nil { - return nil, fmt.Errorf("reading ar magic: %w", err) - } - if string(magic[:]) != "!\n" { - return nil, fmt.Errorf("not an ar file (bad magic %q)", magic) - } - - for { - var hdr [60]byte - if _, err := io.ReadFull(r, hdr[:]); err != nil { - return nil, fmt.Errorf("reading file header: %w", err) - } - filename := strings.TrimSpace(string(hdr[:16])) - size, err := strconv.ParseInt(strings.TrimSpace(string(hdr[48:58])), 10, 64) - if err != nil { - return nil, fmt.Errorf("reading size of file %q: %w", filename, err) - } - if filename == "control.tar.gz" { - return io.LimitReader(r, size), nil - } - - // files in ar are padded out to 2 bytes. - if size%2 == 1 { - size++ - } - if _, err := io.CopyN(io.Discard, r, size); err != nil { - return nil, fmt.Errorf("seeking past file %q: %w", filename, err) - } - } -} - -// findControlFile reads r as a tar.gz archive, finds a file named -// `control` within, and returns its contents. -func findControlFile(r io.Reader) (control []byte, err error) { - gz, err := gzip.NewReader(r) - if err != nil { - return nil, fmt.Errorf("decompressing control.tar.gz: %w", err) - } - defer gz.Close() - - tr := tar.NewReader(gz) - for { - hdr, err := tr.Next() - if err != nil { - if errors.Is(err, io.EOF) { - return nil, errors.New("EOF while looking for control file in control.tar.gz") - } - return nil, fmt.Errorf("reading tar header: %w", err) - } - - if filepath.Clean(hdr.Name) != "control" { - continue - } - - // Found control file - break - } - - bs, err := io.ReadAll(tr) - if err != nil { - return nil, fmt.Errorf("reading control file: %w", err) - } - - return bytes.TrimSpace(bs), nil -} - -var ( - archKey = []byte("Architecture:") - versionKey = []byte("Version:") -) - -// findArchAndVersion extracts the architecture and version strings -// from the given control file. -func findArchAndVersion(control []byte) (arch string, version string, err error) { - b := bytes.NewBuffer(control) - for { - l, err := b.ReadBytes('\n') - if err != nil { - return "", "", err - } - if bytes.HasPrefix(l, archKey) { - arch = string(bytes.TrimSpace(l[len(archKey):])) - } else if bytes.HasPrefix(l, versionKey) { - version = string(bytes.TrimSpace(l[len(versionKey):])) - } - if arch != "" && version != "" { - return arch, version, nil - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package deb extracts metadata from Debian packages. +package deb + +import ( + "archive/tar" + "bufio" + "bytes" + "compress/gzip" + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "strconv" + "strings" +) + +// Info is the Debian package metadata needed to integrate the package +// into a repository. +type Info struct { + // Version is the version of the package, as reported by dpkg. + Version string + // Arch is the Debian CPU architecture the package is for. + Arch string + // Control is the entire contents of the package's control file, + // with leading and trailing whitespace removed. + Control []byte + // MD5 is the MD5 hash of the package file. + MD5 []byte + // SHA1 is the SHA1 hash of the package file. + SHA1 []byte + // SHA256 is the SHA256 hash of the package file. + SHA256 []byte +} + +// ReadFile returns Debian package metadata from the .deb file at path. +func ReadFile(path string) (*Info, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + return Read(f) +} + +// Read returns Debian package metadata from the .deb file in r. +func Read(r io.Reader) (*Info, error) { + b := bufio.NewReader(r) + + m5, s1, s256 := md5.New(), sha1.New(), sha256.New() + summers := io.MultiWriter(m5, s1, s256) + r = io.TeeReader(b, summers) + + t, err := findControlTar(r) + if err != nil { + return nil, fmt.Errorf("searching for control.tar.gz: %w", err) + } + + control, err := findControlFile(t) + if err != nil { + return nil, fmt.Errorf("searching for control file in control.tar.gz: %w", err) + } + + arch, version, err := findArchAndVersion(control) + if err != nil { + return nil, fmt.Errorf("extracting version and architecture from control file: %w", err) + } + + // Exhaust the remainder of r, so that the summers see the entire file. + if _, err := io.Copy(io.Discard, r); err != nil { + return nil, fmt.Errorf("hashing file: %w", err) + } + + return &Info{ + Version: version, + Arch: arch, + Control: control, + MD5: m5.Sum(nil), + SHA1: s1.Sum(nil), + SHA256: s256.Sum(nil), + }, nil +} + +// findControlTar reads r as an `ar` archive, finds a tarball named +// `control.tar.gz` within, and returns a reader for that file. +func findControlTar(r io.Reader) (tarReader io.Reader, err error) { + var magic [8]byte + if _, err := io.ReadFull(r, magic[:]); err != nil { + return nil, fmt.Errorf("reading ar magic: %w", err) + } + if string(magic[:]) != "!\n" { + return nil, fmt.Errorf("not an ar file (bad magic %q)", magic) + } + + for { + var hdr [60]byte + if _, err := io.ReadFull(r, hdr[:]); err != nil { + return nil, fmt.Errorf("reading file header: %w", err) + } + filename := strings.TrimSpace(string(hdr[:16])) + size, err := strconv.ParseInt(strings.TrimSpace(string(hdr[48:58])), 10, 64) + if err != nil { + return nil, fmt.Errorf("reading size of file %q: %w", filename, err) + } + if filename == "control.tar.gz" { + return io.LimitReader(r, size), nil + } + + // files in ar are padded out to 2 bytes. + if size%2 == 1 { + size++ + } + if _, err := io.CopyN(io.Discard, r, size); err != nil { + return nil, fmt.Errorf("seeking past file %q: %w", filename, err) + } + } +} + +// findControlFile reads r as a tar.gz archive, finds a file named +// `control` within, and returns its contents. +func findControlFile(r io.Reader) (control []byte, err error) { + gz, err := gzip.NewReader(r) + if err != nil { + return nil, fmt.Errorf("decompressing control.tar.gz: %w", err) + } + defer gz.Close() + + tr := tar.NewReader(gz) + for { + hdr, err := tr.Next() + if err != nil { + if errors.Is(err, io.EOF) { + return nil, errors.New("EOF while looking for control file in control.tar.gz") + } + return nil, fmt.Errorf("reading tar header: %w", err) + } + + if filepath.Clean(hdr.Name) != "control" { + continue + } + + // Found control file + break + } + + bs, err := io.ReadAll(tr) + if err != nil { + return nil, fmt.Errorf("reading control file: %w", err) + } + + return bytes.TrimSpace(bs), nil +} + +var ( + archKey = []byte("Architecture:") + versionKey = []byte("Version:") +) + +// findArchAndVersion extracts the architecture and version strings +// from the given control file. +func findArchAndVersion(control []byte) (arch string, version string, err error) { + b := bytes.NewBuffer(control) + for { + l, err := b.ReadBytes('\n') + if err != nil { + return "", "", err + } + if bytes.HasPrefix(l, archKey) { + arch = string(bytes.TrimSpace(l[len(archKey):])) + } else if bytes.HasPrefix(l, versionKey) { + version = string(bytes.TrimSpace(l[len(versionKey):])) + } + if arch != "" && version != "" { + return arch, version, nil + } + } +} diff --git a/packages/deb/deb_test.go b/packages/deb/deb_test.go index 0ff43da21..1a25f67ad 100644 --- a/packages/deb/deb_test.go +++ b/packages/deb/deb_test.go @@ -1,205 +1,205 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package deb - -import ( - "bytes" - "crypto/md5" - "crypto/sha1" - "crypto/sha256" - "encoding/hex" - "fmt" - "hash" - "strings" - "testing" - - "github.com/google/go-cmp/cmp" - "github.com/goreleaser/nfpm/v2" - _ "github.com/goreleaser/nfpm/v2/deb" -) - -func TestDebInfo(t *testing.T) { - tests := []struct { - name string - in []byte - want *Info - wantErr bool - }{ - { - name: "simple", - in: mkTestDeb("1.2.3", "amd64"), - want: &Info{ - Version: "1.2.3", - Arch: "amd64", - Control: mkControl( - "Package", "tailscale", - "Version", "1.2.3", - "Section", "net", - "Priority", "extra", - "Architecture", "amd64", - "Maintainer", "Tail Scalar", - "Installed-Size", "0", - "Description", "test package"), - }, - }, - { - name: "arm64", - in: mkTestDeb("1.2.3", "arm64"), - want: &Info{ - Version: "1.2.3", - Arch: "arm64", - Control: mkControl( - "Package", "tailscale", - "Version", "1.2.3", - "Section", "net", - "Priority", "extra", - "Architecture", "arm64", - "Maintainer", "Tail Scalar", - "Installed-Size", "0", - "Description", "test package"), - }, - }, - { - name: "unstable", - in: mkTestDeb("1.7.25", "amd64"), - want: &Info{ - Version: "1.7.25", - Arch: "amd64", - Control: mkControl( - "Package", "tailscale", - "Version", "1.7.25", - "Section", "net", - "Priority", "extra", - "Architecture", "amd64", - "Maintainer", "Tail Scalar", - "Installed-Size", "0", - "Description", "test package"), - }, - }, - - // These truncation tests assume the structure of a .deb - // package, which is as follows: - // magic: 8 bytes - // file header: 60 bytes, before each file blob - // - // The first file in a .deb ar is "debian-binary", which is 4 - // bytes long and consists of "2.0\n". - // The second file is control.tar.gz, which is what we care - // about introspecting for metadata. - // The final file is data.tar.gz, which we don't care about. - // - // The first file in control.tar.gz is the "control" file we - // want to read for metadata. - { - name: "truncated_ar_magic", - in: mkTestDeb("1.7.25", "amd64")[:4], - wantErr: true, - }, - { - name: "truncated_ar_header", - in: mkTestDeb("1.7.25", "amd64")[:30], - wantErr: true, - }, - { - name: "missing_control_tgz", - // Truncate right after the "debian-binary" file, which - // makes the file a valid 1-file archive that's missing - // control.tar.gz. - in: mkTestDeb("1.7.25", "amd64")[:72], - wantErr: true, - }, - { - name: "truncated_tgz", - in: mkTestDeb("1.7.25", "amd64")[:172], - wantErr: true, - }, - } - - for _, test := range tests { - // mkTestDeb returns non-deterministic output due to - // timestamps embedded in the package file, so compute the - // wanted hashes on the fly here. - if test.want != nil { - test.want.MD5 = mkHash(test.in, md5.New) - test.want.SHA1 = mkHash(test.in, sha1.New) - test.want.SHA256 = mkHash(test.in, sha256.New) - } - - t.Run(test.name, func(t *testing.T) { - b := bytes.NewBuffer(test.in) - got, err := Read(b) - if err != nil { - if test.wantErr { - t.Logf("got expected error: %v", err) - return - } - t.Fatalf("reading deb info: %v", err) - } - if diff := diff(got, test.want); diff != "" { - t.Fatalf("parsed info diff (-got+want):\n%s", diff) - } - }) - } -} - -func diff(got, want any) string { - matchField := func(name string) func(p cmp.Path) bool { - return func(p cmp.Path) bool { - if len(p) != 3 { - return false - } - return p[2].String() == "."+name - } - } - toLines := cmp.Transformer("lines", func(b []byte) []string { return strings.Split(string(b), "\n") }) - toHex := cmp.Transformer("hex", func(b []byte) string { return hex.EncodeToString(b) }) - return cmp.Diff(got, want, - cmp.FilterPath(matchField("Control"), toLines), - cmp.FilterPath(matchField("MD5"), toHex), - cmp.FilterPath(matchField("SHA1"), toHex), - cmp.FilterPath(matchField("SHA256"), toHex)) -} - -func mkTestDeb(version, arch string) []byte { - info := nfpm.WithDefaults(&nfpm.Info{ - Name: "tailscale", - Description: "test package", - Arch: arch, - Platform: "linux", - Version: version, - Section: "net", - Priority: "extra", - Maintainer: "Tail Scalar", - }) - - pkg, err := nfpm.Get("deb") - if err != nil { - panic(fmt.Sprintf("getting deb packager: %v", err)) - } - - var b bytes.Buffer - if err := pkg.Package(info, &b); err != nil { - panic(fmt.Sprintf("creating deb package: %v", err)) - } - - return b.Bytes() -} - -func mkControl(fs ...string) []byte { - if len(fs)%2 != 0 { - panic("odd number of control file fields") - } - var b bytes.Buffer - for i := 0; i < len(fs); i = i + 2 { - k, v := fs[i], fs[i+1] - fmt.Fprintf(&b, "%s: %s\n", k, v) - } - return bytes.TrimSpace(b.Bytes()) -} - -func mkHash(b []byte, hasher func() hash.Hash) []byte { - h := hasher() - h.Write(b) - return h.Sum(nil) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package deb + +import ( + "bytes" + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "encoding/hex" + "fmt" + "hash" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/goreleaser/nfpm/v2" + _ "github.com/goreleaser/nfpm/v2/deb" +) + +func TestDebInfo(t *testing.T) { + tests := []struct { + name string + in []byte + want *Info + wantErr bool + }{ + { + name: "simple", + in: mkTestDeb("1.2.3", "amd64"), + want: &Info{ + Version: "1.2.3", + Arch: "amd64", + Control: mkControl( + "Package", "tailscale", + "Version", "1.2.3", + "Section", "net", + "Priority", "extra", + "Architecture", "amd64", + "Maintainer", "Tail Scalar", + "Installed-Size", "0", + "Description", "test package"), + }, + }, + { + name: "arm64", + in: mkTestDeb("1.2.3", "arm64"), + want: &Info{ + Version: "1.2.3", + Arch: "arm64", + Control: mkControl( + "Package", "tailscale", + "Version", "1.2.3", + "Section", "net", + "Priority", "extra", + "Architecture", "arm64", + "Maintainer", "Tail Scalar", + "Installed-Size", "0", + "Description", "test package"), + }, + }, + { + name: "unstable", + in: mkTestDeb("1.7.25", "amd64"), + want: &Info{ + Version: "1.7.25", + Arch: "amd64", + Control: mkControl( + "Package", "tailscale", + "Version", "1.7.25", + "Section", "net", + "Priority", "extra", + "Architecture", "amd64", + "Maintainer", "Tail Scalar", + "Installed-Size", "0", + "Description", "test package"), + }, + }, + + // These truncation tests assume the structure of a .deb + // package, which is as follows: + // magic: 8 bytes + // file header: 60 bytes, before each file blob + // + // The first file in a .deb ar is "debian-binary", which is 4 + // bytes long and consists of "2.0\n". + // The second file is control.tar.gz, which is what we care + // about introspecting for metadata. + // The final file is data.tar.gz, which we don't care about. + // + // The first file in control.tar.gz is the "control" file we + // want to read for metadata. + { + name: "truncated_ar_magic", + in: mkTestDeb("1.7.25", "amd64")[:4], + wantErr: true, + }, + { + name: "truncated_ar_header", + in: mkTestDeb("1.7.25", "amd64")[:30], + wantErr: true, + }, + { + name: "missing_control_tgz", + // Truncate right after the "debian-binary" file, which + // makes the file a valid 1-file archive that's missing + // control.tar.gz. + in: mkTestDeb("1.7.25", "amd64")[:72], + wantErr: true, + }, + { + name: "truncated_tgz", + in: mkTestDeb("1.7.25", "amd64")[:172], + wantErr: true, + }, + } + + for _, test := range tests { + // mkTestDeb returns non-deterministic output due to + // timestamps embedded in the package file, so compute the + // wanted hashes on the fly here. + if test.want != nil { + test.want.MD5 = mkHash(test.in, md5.New) + test.want.SHA1 = mkHash(test.in, sha1.New) + test.want.SHA256 = mkHash(test.in, sha256.New) + } + + t.Run(test.name, func(t *testing.T) { + b := bytes.NewBuffer(test.in) + got, err := Read(b) + if err != nil { + if test.wantErr { + t.Logf("got expected error: %v", err) + return + } + t.Fatalf("reading deb info: %v", err) + } + if diff := diff(got, test.want); diff != "" { + t.Fatalf("parsed info diff (-got+want):\n%s", diff) + } + }) + } +} + +func diff(got, want any) string { + matchField := func(name string) func(p cmp.Path) bool { + return func(p cmp.Path) bool { + if len(p) != 3 { + return false + } + return p[2].String() == "."+name + } + } + toLines := cmp.Transformer("lines", func(b []byte) []string { return strings.Split(string(b), "\n") }) + toHex := cmp.Transformer("hex", func(b []byte) string { return hex.EncodeToString(b) }) + return cmp.Diff(got, want, + cmp.FilterPath(matchField("Control"), toLines), + cmp.FilterPath(matchField("MD5"), toHex), + cmp.FilterPath(matchField("SHA1"), toHex), + cmp.FilterPath(matchField("SHA256"), toHex)) +} + +func mkTestDeb(version, arch string) []byte { + info := nfpm.WithDefaults(&nfpm.Info{ + Name: "tailscale", + Description: "test package", + Arch: arch, + Platform: "linux", + Version: version, + Section: "net", + Priority: "extra", + Maintainer: "Tail Scalar", + }) + + pkg, err := nfpm.Get("deb") + if err != nil { + panic(fmt.Sprintf("getting deb packager: %v", err)) + } + + var b bytes.Buffer + if err := pkg.Package(info, &b); err != nil { + panic(fmt.Sprintf("creating deb package: %v", err)) + } + + return b.Bytes() +} + +func mkControl(fs ...string) []byte { + if len(fs)%2 != 0 { + panic("odd number of control file fields") + } + var b bytes.Buffer + for i := 0; i < len(fs); i = i + 2 { + k, v := fs[i], fs[i+1] + fmt.Fprintf(&b, "%s: %s\n", k, v) + } + return bytes.TrimSpace(b.Bytes()) +} + +func mkHash(b []byte, hasher func() hash.Hash) []byte { + h := hasher() + h.Write(b) + return h.Sum(nil) +} diff --git a/paths/migrate.go b/paths/migrate.go index 11d90a627..3a23ecca3 100644 --- a/paths/migrate.go +++ b/paths/migrate.go @@ -1,58 +1,58 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package paths - -import ( - "os" - "path/filepath" - - "tailscale.com/types/logger" -) - -// TryConfigFileMigration carefully copies the contents of oldFile to -// newFile, returning the path which should be used to read the config. -// - if newFile already exists, don't modify it just return its path -// - if neither oldFile nor newFile exist, return newFile for a fresh -// default config to be written to. -// - if oldFile exists but copying to newFile fails, return oldFile so -// there will at least be some config to work with. -func TryConfigFileMigration(logf logger.Logf, oldFile, newFile string) string { - _, err := os.Stat(newFile) - if err == nil { - // Common case for a system which has already been migrated. - return newFile - } - if !os.IsNotExist(err) { - logf("TryConfigFileMigration failed; new file: %v", err) - return newFile - } - - contents, err := os.ReadFile(oldFile) - if err != nil { - // Common case for a new user. - return newFile - } - - if err = MkStateDir(filepath.Dir(newFile)); err != nil { - logf("TryConfigFileMigration failed; MkStateDir: %v", err) - return oldFile - } - - err = os.WriteFile(newFile, contents, 0600) - if err != nil { - removeErr := os.Remove(newFile) - if removeErr != nil { - logf("TryConfigFileMigration failed; write newFile no cleanup: %v, remove err: %v", - err, removeErr) - return oldFile - } - logf("TryConfigFileMigration failed; write newFile: %v", err) - return oldFile - } - - logf("TryConfigFileMigration: successfully migrated: from %v to %v", - oldFile, newFile) - - return newFile -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package paths + +import ( + "os" + "path/filepath" + + "tailscale.com/types/logger" +) + +// TryConfigFileMigration carefully copies the contents of oldFile to +// newFile, returning the path which should be used to read the config. +// - if newFile already exists, don't modify it just return its path +// - if neither oldFile nor newFile exist, return newFile for a fresh +// default config to be written to. +// - if oldFile exists but copying to newFile fails, return oldFile so +// there will at least be some config to work with. +func TryConfigFileMigration(logf logger.Logf, oldFile, newFile string) string { + _, err := os.Stat(newFile) + if err == nil { + // Common case for a system which has already been migrated. + return newFile + } + if !os.IsNotExist(err) { + logf("TryConfigFileMigration failed; new file: %v", err) + return newFile + } + + contents, err := os.ReadFile(oldFile) + if err != nil { + // Common case for a new user. + return newFile + } + + if err = MkStateDir(filepath.Dir(newFile)); err != nil { + logf("TryConfigFileMigration failed; MkStateDir: %v", err) + return oldFile + } + + err = os.WriteFile(newFile, contents, 0600) + if err != nil { + removeErr := os.Remove(newFile) + if removeErr != nil { + logf("TryConfigFileMigration failed; write newFile no cleanup: %v, remove err: %v", + err, removeErr) + return oldFile + } + logf("TryConfigFileMigration failed; write newFile: %v", err) + return oldFile + } + + logf("TryConfigFileMigration: successfully migrated: from %v to %v", + oldFile, newFile) + + return newFile +} diff --git a/paths/paths.go b/paths/paths.go index 8cee4cabf..28c3be02a 100644 --- a/paths/paths.go +++ b/paths/paths.go @@ -1,92 +1,92 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package paths returns platform and user-specific default paths to -// Tailscale files and directories. -package paths - -import ( - "os" - "path/filepath" - "runtime" - - "tailscale.com/syncs" - "tailscale.com/version/distro" -) - -// AppSharedDir is a string set by the iOS or Android app on start -// containing a directory we can read/write in. -var AppSharedDir syncs.AtomicValue[string] - -// DefaultTailscaledSocket returns the path to the tailscaled Unix socket -// or the empty string if there's no reasonable default. -func DefaultTailscaledSocket() string { - if runtime.GOOS == "windows" { - return `\\.\pipe\ProtectedPrefix\Administrators\Tailscale\tailscaled` - } - if runtime.GOOS == "darwin" { - return "/var/run/tailscaled.socket" - } - if runtime.GOOS == "plan9" { - return "/srv/tailscaled.sock" - } - switch distro.Get() { - case distro.Synology: - if distro.DSMVersion() == 6 { - return "/var/packages/Tailscale/etc/tailscaled.sock" - } - // DSM 7 (and higher? or failure to detect.) - return "/var/packages/Tailscale/var/tailscaled.sock" - case distro.Gokrazy: - return "/perm/tailscaled/tailscaled.sock" - case distro.QNAP: - return "/tmp/tailscale/tailscaled.sock" - } - if fi, err := os.Stat("/var/run"); err == nil && fi.IsDir() { - return "/var/run/tailscale/tailscaled.sock" - } - return "tailscaled.sock" -} - -// Overridden in init by OS-specific files. -var ( - stateFileFunc func() string - - // ensureStateDirPerms applies a restrictive ACL/chmod - // to the provided directory. - ensureStateDirPerms = func(string) error { return nil } -) - -// DefaultTailscaledStateFile returns the default path to the -// tailscaled state file, or the empty string if there's no reasonable -// default value. -func DefaultTailscaledStateFile() string { - if f := stateFileFunc; f != nil { - return f() - } - if runtime.GOOS == "windows" { - return filepath.Join(os.Getenv("ProgramData"), "Tailscale", "server-state.conf") - } - return "" -} - -// MkStateDir ensures that dirPath, the daemon's configuration directory -// containing machine keys etc, both exists and has the correct permissions. -// We want it to only be accessible to the user the daemon is running under. -func MkStateDir(dirPath string) error { - if err := os.MkdirAll(dirPath, 0700); err != nil { - return err - } - return ensureStateDirPerms(dirPath) -} - -// LegacyStateFilePath returns the legacy path to the state file when -// it was stored under the current user's %LocalAppData%. -// -// It is only called on Windows. -func LegacyStateFilePath() string { - if runtime.GOOS == "windows" { - return filepath.Join(os.Getenv("LocalAppData"), "Tailscale", "server-state.conf") - } - return "" -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package paths returns platform and user-specific default paths to +// Tailscale files and directories. +package paths + +import ( + "os" + "path/filepath" + "runtime" + + "tailscale.com/syncs" + "tailscale.com/version/distro" +) + +// AppSharedDir is a string set by the iOS or Android app on start +// containing a directory we can read/write in. +var AppSharedDir syncs.AtomicValue[string] + +// DefaultTailscaledSocket returns the path to the tailscaled Unix socket +// or the empty string if there's no reasonable default. +func DefaultTailscaledSocket() string { + if runtime.GOOS == "windows" { + return `\\.\pipe\ProtectedPrefix\Administrators\Tailscale\tailscaled` + } + if runtime.GOOS == "darwin" { + return "/var/run/tailscaled.socket" + } + if runtime.GOOS == "plan9" { + return "/srv/tailscaled.sock" + } + switch distro.Get() { + case distro.Synology: + if distro.DSMVersion() == 6 { + return "/var/packages/Tailscale/etc/tailscaled.sock" + } + // DSM 7 (and higher? or failure to detect.) + return "/var/packages/Tailscale/var/tailscaled.sock" + case distro.Gokrazy: + return "/perm/tailscaled/tailscaled.sock" + case distro.QNAP: + return "/tmp/tailscale/tailscaled.sock" + } + if fi, err := os.Stat("/var/run"); err == nil && fi.IsDir() { + return "/var/run/tailscale/tailscaled.sock" + } + return "tailscaled.sock" +} + +// Overridden in init by OS-specific files. +var ( + stateFileFunc func() string + + // ensureStateDirPerms applies a restrictive ACL/chmod + // to the provided directory. + ensureStateDirPerms = func(string) error { return nil } +) + +// DefaultTailscaledStateFile returns the default path to the +// tailscaled state file, or the empty string if there's no reasonable +// default value. +func DefaultTailscaledStateFile() string { + if f := stateFileFunc; f != nil { + return f() + } + if runtime.GOOS == "windows" { + return filepath.Join(os.Getenv("ProgramData"), "Tailscale", "server-state.conf") + } + return "" +} + +// MkStateDir ensures that dirPath, the daemon's configuration directory +// containing machine keys etc, both exists and has the correct permissions. +// We want it to only be accessible to the user the daemon is running under. +func MkStateDir(dirPath string) error { + if err := os.MkdirAll(dirPath, 0700); err != nil { + return err + } + return ensureStateDirPerms(dirPath) +} + +// LegacyStateFilePath returns the legacy path to the state file when +// it was stored under the current user's %LocalAppData%. +// +// It is only called on Windows. +func LegacyStateFilePath() string { + if runtime.GOOS == "windows" { + return filepath.Join(os.Getenv("LocalAppData"), "Tailscale", "server-state.conf") + } + return "" +} diff --git a/paths/paths_windows.go b/paths/paths_windows.go index 224981049..470540065 100644 --- a/paths/paths_windows.go +++ b/paths/paths_windows.go @@ -1,100 +1,100 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package paths - -import ( - "os" - "path/filepath" - "strings" - - "golang.org/x/sys/windows" - "tailscale.com/util/winutil" -) - -func init() { - ensureStateDirPerms = ensureStateDirPermsWindows -} - -// ensureStateDirPermsWindows applies a restrictive ACL to the directory specified by dirPath. -// It sets the following security attributes on the directory: -// Owner: The user for the current process; -// Primary Group: The primary group for the current process; -// DACL: Full control to the current user and to the Administrators group. -// -// (We include Administrators so that admin users may still access logs; -// granting access exclusively to LocalSystem would require admins to use -// special tools to access the Log directory) -// -// Inheritance: The directory does not inherit the ACL from its parent. -// -// However, any directories and/or files created within this -// directory *do* inherit the ACL that we are setting. -func ensureStateDirPermsWindows(dirPath string) error { - fi, err := os.Stat(dirPath) - if err != nil { - return err - } - if !fi.IsDir() { - return os.ErrInvalid - } - if strings.ToLower(filepath.Base(dirPath)) != "tailscale" { - return nil - } - - // We need the info for our current user as SIDs - sids, err := winutil.GetCurrentUserSIDs() - if err != nil { - return err - } - - // We also need the SID for the Administrators group so that admins may - // easily access logs. - adminGroupSid, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid) - if err != nil { - return err - } - - // Munge the SIDs into the format required by EXPLICIT_ACCESS. - userTrustee := windows.TRUSTEE{nil, windows.NO_MULTIPLE_TRUSTEE, - windows.TRUSTEE_IS_SID, windows.TRUSTEE_IS_USER, - windows.TrusteeValueFromSID(sids.User)} - - adminTrustee := windows.TRUSTEE{nil, windows.NO_MULTIPLE_TRUSTEE, - windows.TRUSTEE_IS_SID, windows.TRUSTEE_IS_WELL_KNOWN_GROUP, - windows.TrusteeValueFromSID(adminGroupSid)} - - // We declare our access rights via this array of EXPLICIT_ACCESS structures. - // We set full access to our user and to Administrators. - // We configure the DACL such that any files or directories created within - // dirPath will also inherit this DACL. - explicitAccess := []windows.EXPLICIT_ACCESS{ - { - windows.GENERIC_ALL, - windows.SET_ACCESS, - windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT, - userTrustee, - }, - { - windows.GENERIC_ALL, - windows.SET_ACCESS, - windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT, - adminTrustee, - }, - } - - dacl, err := windows.ACLFromEntries(explicitAccess, nil) - if err != nil { - return err - } - - // We now reset the file's owner, primary group, and DACL. - // We also must pass PROTECTED_DACL_SECURITY_INFORMATION so that our new ACL - // does not inherit any ACL entries from the parent directory. - const flags = windows.OWNER_SECURITY_INFORMATION | - windows.GROUP_SECURITY_INFORMATION | - windows.DACL_SECURITY_INFORMATION | - windows.PROTECTED_DACL_SECURITY_INFORMATION - return windows.SetNamedSecurityInfo(dirPath, windows.SE_FILE_OBJECT, flags, - sids.User, sids.PrimaryGroup, dacl, nil) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package paths + +import ( + "os" + "path/filepath" + "strings" + + "golang.org/x/sys/windows" + "tailscale.com/util/winutil" +) + +func init() { + ensureStateDirPerms = ensureStateDirPermsWindows +} + +// ensureStateDirPermsWindows applies a restrictive ACL to the directory specified by dirPath. +// It sets the following security attributes on the directory: +// Owner: The user for the current process; +// Primary Group: The primary group for the current process; +// DACL: Full control to the current user and to the Administrators group. +// +// (We include Administrators so that admin users may still access logs; +// granting access exclusively to LocalSystem would require admins to use +// special tools to access the Log directory) +// +// Inheritance: The directory does not inherit the ACL from its parent. +// +// However, any directories and/or files created within this +// directory *do* inherit the ACL that we are setting. +func ensureStateDirPermsWindows(dirPath string) error { + fi, err := os.Stat(dirPath) + if err != nil { + return err + } + if !fi.IsDir() { + return os.ErrInvalid + } + if strings.ToLower(filepath.Base(dirPath)) != "tailscale" { + return nil + } + + // We need the info for our current user as SIDs + sids, err := winutil.GetCurrentUserSIDs() + if err != nil { + return err + } + + // We also need the SID for the Administrators group so that admins may + // easily access logs. + adminGroupSid, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid) + if err != nil { + return err + } + + // Munge the SIDs into the format required by EXPLICIT_ACCESS. + userTrustee := windows.TRUSTEE{nil, windows.NO_MULTIPLE_TRUSTEE, + windows.TRUSTEE_IS_SID, windows.TRUSTEE_IS_USER, + windows.TrusteeValueFromSID(sids.User)} + + adminTrustee := windows.TRUSTEE{nil, windows.NO_MULTIPLE_TRUSTEE, + windows.TRUSTEE_IS_SID, windows.TRUSTEE_IS_WELL_KNOWN_GROUP, + windows.TrusteeValueFromSID(adminGroupSid)} + + // We declare our access rights via this array of EXPLICIT_ACCESS structures. + // We set full access to our user and to Administrators. + // We configure the DACL such that any files or directories created within + // dirPath will also inherit this DACL. + explicitAccess := []windows.EXPLICIT_ACCESS{ + { + windows.GENERIC_ALL, + windows.SET_ACCESS, + windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT, + userTrustee, + }, + { + windows.GENERIC_ALL, + windows.SET_ACCESS, + windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT, + adminTrustee, + }, + } + + dacl, err := windows.ACLFromEntries(explicitAccess, nil) + if err != nil { + return err + } + + // We now reset the file's owner, primary group, and DACL. + // We also must pass PROTECTED_DACL_SECURITY_INFORMATION so that our new ACL + // does not inherit any ACL entries from the parent directory. + const flags = windows.OWNER_SECURITY_INFORMATION | + windows.GROUP_SECURITY_INFORMATION | + windows.DACL_SECURITY_INFORMATION | + windows.PROTECTED_DACL_SECURITY_INFORMATION + return windows.SetNamedSecurityInfo(dirPath, windows.SE_FILE_OBJECT, flags, + sids.User, sids.PrimaryGroup, dacl, nil) +} diff --git a/portlist/clean.go b/portlist/clean.go index cad1562c3..7e137de94 100644 --- a/portlist/clean.go +++ b/portlist/clean.go @@ -1,36 +1,36 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package portlist - -import ( - "path/filepath" - "strings" -) - -// argvSubject takes a command and its flags, and returns the -// short/pretty name for the process. This is usually the basename of -// the binary being executed, but can sometimes vary (e.g. so that we -// don't report all Java programs as "java"). -func argvSubject(argv ...string) string { - if len(argv) == 0 { - return "" - } - ret := filepath.Base(argv[0]) - - // Handle special cases. - switch { - case ret == "mono" && len(argv) >= 2: - // .Net programs execute as `mono actualProgram.exe`. - ret = filepath.Base(argv[1]) - } - - // Handle space separated argv - ret, _, _ = strings.Cut(ret, " ") - - // Remove common noise. - ret = strings.TrimSpace(ret) - ret = strings.TrimSuffix(ret, ".exe") - - return ret -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package portlist + +import ( + "path/filepath" + "strings" +) + +// argvSubject takes a command and its flags, and returns the +// short/pretty name for the process. This is usually the basename of +// the binary being executed, but can sometimes vary (e.g. so that we +// don't report all Java programs as "java"). +func argvSubject(argv ...string) string { + if len(argv) == 0 { + return "" + } + ret := filepath.Base(argv[0]) + + // Handle special cases. + switch { + case ret == "mono" && len(argv) >= 2: + // .Net programs execute as `mono actualProgram.exe`. + ret = filepath.Base(argv[1]) + } + + // Handle space separated argv + ret, _, _ = strings.Cut(ret, " ") + + // Remove common noise. + ret = strings.TrimSpace(ret) + ret = strings.TrimSuffix(ret, ".exe") + + return ret +} diff --git a/portlist/clean_test.go b/portlist/clean_test.go index cca18ab8e..5a1e34405 100644 --- a/portlist/clean_test.go +++ b/portlist/clean_test.go @@ -1,57 +1,57 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package portlist - -import "testing" - -func TestArgvSubject(t *testing.T) { - tests := []struct { - in []string - want string - }{ - { - in: nil, - want: "", - }, - { - in: []string{"/usr/bin/sshd"}, - want: "sshd", - }, - { - in: []string{"/bin/mono"}, - want: "mono", - }, - { - in: []string{"/nix/store/x2cw2xjw98zdysf56bdlfzsr7cyxv0jf-mono-5.20.1.27/bin/mono", "/bin/exampleProgram.exe"}, - want: "exampleProgram", - }, - { - in: []string{"/bin/mono", "/sbin/exampleProgram.bin"}, - want: "exampleProgram.bin", - }, - { - in: []string{"/usr/bin/sshd_config [listener] 1 of 10-100 startups"}, - want: "sshd_config", - }, - { - in: []string{"/usr/bin/sshd [listener] 0 of 10-100 startups"}, - want: "sshd", - }, - { - in: []string{"/opt/aws/bin/eic_run_authorized_keys %u %f -o AuthorizedKeysCommandUser ec2-instance-connect [listener] 0 of 10-100 startups"}, - want: "eic_run_authorized_keys", - }, - { - in: []string{"/usr/bin/nginx worker"}, - want: "nginx", - }, - } - - for _, test := range tests { - got := argvSubject(test.in...) - if got != test.want { - t.Errorf("argvSubject(%v) = %q, want %q", test.in, got, test.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package portlist + +import "testing" + +func TestArgvSubject(t *testing.T) { + tests := []struct { + in []string + want string + }{ + { + in: nil, + want: "", + }, + { + in: []string{"/usr/bin/sshd"}, + want: "sshd", + }, + { + in: []string{"/bin/mono"}, + want: "mono", + }, + { + in: []string{"/nix/store/x2cw2xjw98zdysf56bdlfzsr7cyxv0jf-mono-5.20.1.27/bin/mono", "/bin/exampleProgram.exe"}, + want: "exampleProgram", + }, + { + in: []string{"/bin/mono", "/sbin/exampleProgram.bin"}, + want: "exampleProgram.bin", + }, + { + in: []string{"/usr/bin/sshd_config [listener] 1 of 10-100 startups"}, + want: "sshd_config", + }, + { + in: []string{"/usr/bin/sshd [listener] 0 of 10-100 startups"}, + want: "sshd", + }, + { + in: []string{"/opt/aws/bin/eic_run_authorized_keys %u %f -o AuthorizedKeysCommandUser ec2-instance-connect [listener] 0 of 10-100 startups"}, + want: "eic_run_authorized_keys", + }, + { + in: []string{"/usr/bin/nginx worker"}, + want: "nginx", + }, + } + + for _, test := range tests { + got := argvSubject(test.in...) + if got != test.want { + t.Errorf("argvSubject(%v) = %q, want %q", test.in, got, test.want) + } + } +} diff --git a/portlist/netstat_test.go b/portlist/netstat_test.go index d04b657f6..023b75b79 100644 --- a/portlist/netstat_test.go +++ b/portlist/netstat_test.go @@ -1,92 +1,92 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build darwin && !ios - -package portlist - -import ( - "bufio" - "encoding/json" - "fmt" - "strings" - "testing" - - "go4.org/mem" -) - -func TestParsePort(t *testing.T) { - type InOut struct { - in string - expect int - } - tests := []InOut{ - {"1.2.3.4:5678", 5678}, - {"0.0.0.0.999", 999}, - {"1.2.3.4:*", 0}, - {"5.5.5.5:0", 0}, - {"[1::2]:5", 5}, - {"[1::2].5", 5}, - {"gibberish", -1}, - } - - for _, io := range tests { - got := parsePort(mem.S(io.in)) - if got != io.expect { - t.Fatalf("input:%#v expect:%v got:%v\n", io.in, io.expect, got) - } - } -} - -const netstatOutput = ` -// macOS -tcp4 0 0 *.23 *.* LISTEN -tcp6 0 0 *.24 *.* LISTEN -tcp4 0 0 *.8185 *.* LISTEN -tcp4 0 0 127.0.0.1.8186 *.* LISTEN -tcp6 0 0 ::1.8187 *.* LISTEN -tcp4 0 0 127.1.2.3.8188 *.* LISTEN - -udp6 0 0 *.106 *.* -udp4 0 0 *.104 *.* -udp46 0 0 *.146 *.* -` - -func TestParsePortsNetstat(t *testing.T) { - for _, loopBack := range [...]bool{false, true} { - t.Run(fmt.Sprintf("loopback_%v", loopBack), func(t *testing.T) { - want := List{ - {"tcp", 23, "", 0}, - {"tcp", 24, "", 0}, - {"udp", 104, "", 0}, - {"udp", 106, "", 0}, - {"udp", 146, "", 0}, - {"tcp", 8185, "", 0}, // but not 8186, 8187, 8188 on localhost, when loopback is false - } - if loopBack { - want = append(want, - Port{"tcp", 8186, "", 0}, - Port{"tcp", 8187, "", 0}, - Port{"tcp", 8188, "", 0}, - ) - } - pl, err := appendParsePortsNetstat(nil, bufio.NewReader(strings.NewReader(netstatOutput)), loopBack) - if err != nil { - t.Fatal(err) - } - pl = sortAndDedup(pl) - jgot, _ := json.MarshalIndent(pl, "", "\t") - jwant, _ := json.MarshalIndent(want, "", "\t") - if len(pl) != len(want) { - t.Fatalf("Got:\n%s\n\nWant:\n%s\n", jgot, jwant) - } - for i := range pl { - if pl[i] != want[i] { - t.Errorf("row#%d\n got: %+v\n\nwant: %+v\n", - i, pl[i], want[i]) - t.Fatalf("Got:\n%s\n\nWant:\n%s\n", jgot, jwant) - } - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build darwin && !ios + +package portlist + +import ( + "bufio" + "encoding/json" + "fmt" + "strings" + "testing" + + "go4.org/mem" +) + +func TestParsePort(t *testing.T) { + type InOut struct { + in string + expect int + } + tests := []InOut{ + {"1.2.3.4:5678", 5678}, + {"0.0.0.0.999", 999}, + {"1.2.3.4:*", 0}, + {"5.5.5.5:0", 0}, + {"[1::2]:5", 5}, + {"[1::2].5", 5}, + {"gibberish", -1}, + } + + for _, io := range tests { + got := parsePort(mem.S(io.in)) + if got != io.expect { + t.Fatalf("input:%#v expect:%v got:%v\n", io.in, io.expect, got) + } + } +} + +const netstatOutput = ` +// macOS +tcp4 0 0 *.23 *.* LISTEN +tcp6 0 0 *.24 *.* LISTEN +tcp4 0 0 *.8185 *.* LISTEN +tcp4 0 0 127.0.0.1.8186 *.* LISTEN +tcp6 0 0 ::1.8187 *.* LISTEN +tcp4 0 0 127.1.2.3.8188 *.* LISTEN + +udp6 0 0 *.106 *.* +udp4 0 0 *.104 *.* +udp46 0 0 *.146 *.* +` + +func TestParsePortsNetstat(t *testing.T) { + for _, loopBack := range [...]bool{false, true} { + t.Run(fmt.Sprintf("loopback_%v", loopBack), func(t *testing.T) { + want := List{ + {"tcp", 23, "", 0}, + {"tcp", 24, "", 0}, + {"udp", 104, "", 0}, + {"udp", 106, "", 0}, + {"udp", 146, "", 0}, + {"tcp", 8185, "", 0}, // but not 8186, 8187, 8188 on localhost, when loopback is false + } + if loopBack { + want = append(want, + Port{"tcp", 8186, "", 0}, + Port{"tcp", 8187, "", 0}, + Port{"tcp", 8188, "", 0}, + ) + } + pl, err := appendParsePortsNetstat(nil, bufio.NewReader(strings.NewReader(netstatOutput)), loopBack) + if err != nil { + t.Fatal(err) + } + pl = sortAndDedup(pl) + jgot, _ := json.MarshalIndent(pl, "", "\t") + jwant, _ := json.MarshalIndent(want, "", "\t") + if len(pl) != len(want) { + t.Fatalf("Got:\n%s\n\nWant:\n%s\n", jgot, jwant) + } + for i := range pl { + if pl[i] != want[i] { + t.Errorf("row#%d\n got: %+v\n\nwant: %+v\n", + i, pl[i], want[i]) + t.Fatalf("Got:\n%s\n\nWant:\n%s\n", jgot, jwant) + } + } + }) + } +} diff --git a/portlist/poller.go b/portlist/poller.go index 226f3b995..423bad3be 100644 --- a/portlist/poller.go +++ b/portlist/poller.go @@ -1,122 +1,122 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// This file contains the code related to the Poller type and its methods. -// The hot loop to keep efficient is Poller.Run. - -package portlist - -import ( - "errors" - "fmt" - "runtime" - "slices" - "sync" - "time" - - "tailscale.com/envknob" -) - -var ( - newOSImpl func(includeLocalhost bool) osImpl // if non-nil, constructs a new osImpl. - pollInterval = 5 * time.Second // default; changed by some OS-specific init funcs - debugDisablePortlist = envknob.RegisterBool("TS_DEBUG_DISABLE_PORTLIST") -) - -// PollInterval is the recommended OS-specific interval -// to wait between *Poller.Poll method calls. -func PollInterval() time.Duration { - return pollInterval -} - -// Poller scans the systems for listening ports periodically and sends -// the results to C. -type Poller struct { - // IncludeLocalhost controls whether services bound to localhost are included. - // - // This field should only be changed before calling Run. - IncludeLocalhost bool - - // os, if non-nil, is an OS-specific implementation of the portlist getting - // code. When non-nil, it's responsible for getting the complete list of - // cached ports complete with the process name. That is, when set, - // addProcesses is not used. - // A nil values means we don't have code for getting the list on the current - // operating system. - os osImpl - initOnce sync.Once // guards init of os - initErr error - - // scatch is memory for Poller.getList to reuse between calls. - scratch []Port - - prev List // most recent data, not aliasing scratch -} - -// osImpl is the OS-specific implementation of getting the open listening ports. -type osImpl interface { - Close() error - - // AppendListeningPorts appends to base (which must have length 0 but - // optional capacity) the list of listening ports. The Port struct should be - // populated as completely as possible. Another pass will not add anything - // to it. - // - // The appended ports should be in a sorted (or at least stable) order so - // the caller can cheaply detect when there are no changes. - AppendListeningPorts(base []Port) ([]Port, error) -} - -func (p *Poller) setPrev(pl List) { - // Make a copy, as the pass in pl slice aliases pl.scratch and we don't want - // that to except to the caller. - p.prev = slices.Clone(pl) -} - -// init initializes the Poller by ensuring it has an underlying -// OS implementation and is not turned off by envknob. -func (p *Poller) init() { - switch { - case debugDisablePortlist(): - p.initErr = errors.New("portlist disabled by envknob") - case newOSImpl == nil: - p.initErr = errors.New("portlist poller not implemented on " + runtime.GOOS) - default: - p.os = newOSImpl(p.IncludeLocalhost) - } -} - -// Close closes the Poller. -func (p *Poller) Close() error { - if p.initErr != nil { - return p.initErr - } - if p.os == nil { - return nil - } - return p.os.Close() -} - -// Poll returns the list of listening ports, if changed from -// a previous call as indicated by the changed result. -func (p *Poller) Poll() (ports []Port, changed bool, err error) { - p.initOnce.Do(p.init) - if p.initErr != nil { - return nil, false, fmt.Errorf("error initializing poller: %w", p.initErr) - } - pl, err := p.getList() - if err != nil { - return nil, false, err - } - if pl.equal(p.prev) { - return nil, false, nil - } - p.setPrev(pl) - return p.prev, true, nil -} - -func (p *Poller) getList() (List, error) { - var err error - p.scratch, err = p.os.AppendListeningPorts(p.scratch[:0]) - return p.scratch, err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// This file contains the code related to the Poller type and its methods. +// The hot loop to keep efficient is Poller.Run. + +package portlist + +import ( + "errors" + "fmt" + "runtime" + "slices" + "sync" + "time" + + "tailscale.com/envknob" +) + +var ( + newOSImpl func(includeLocalhost bool) osImpl // if non-nil, constructs a new osImpl. + pollInterval = 5 * time.Second // default; changed by some OS-specific init funcs + debugDisablePortlist = envknob.RegisterBool("TS_DEBUG_DISABLE_PORTLIST") +) + +// PollInterval is the recommended OS-specific interval +// to wait between *Poller.Poll method calls. +func PollInterval() time.Duration { + return pollInterval +} + +// Poller scans the systems for listening ports periodically and sends +// the results to C. +type Poller struct { + // IncludeLocalhost controls whether services bound to localhost are included. + // + // This field should only be changed before calling Run. + IncludeLocalhost bool + + // os, if non-nil, is an OS-specific implementation of the portlist getting + // code. When non-nil, it's responsible for getting the complete list of + // cached ports complete with the process name. That is, when set, + // addProcesses is not used. + // A nil values means we don't have code for getting the list on the current + // operating system. + os osImpl + initOnce sync.Once // guards init of os + initErr error + + // scatch is memory for Poller.getList to reuse between calls. + scratch []Port + + prev List // most recent data, not aliasing scratch +} + +// osImpl is the OS-specific implementation of getting the open listening ports. +type osImpl interface { + Close() error + + // AppendListeningPorts appends to base (which must have length 0 but + // optional capacity) the list of listening ports. The Port struct should be + // populated as completely as possible. Another pass will not add anything + // to it. + // + // The appended ports should be in a sorted (or at least stable) order so + // the caller can cheaply detect when there are no changes. + AppendListeningPorts(base []Port) ([]Port, error) +} + +func (p *Poller) setPrev(pl List) { + // Make a copy, as the pass in pl slice aliases pl.scratch and we don't want + // that to except to the caller. + p.prev = slices.Clone(pl) +} + +// init initializes the Poller by ensuring it has an underlying +// OS implementation and is not turned off by envknob. +func (p *Poller) init() { + switch { + case debugDisablePortlist(): + p.initErr = errors.New("portlist disabled by envknob") + case newOSImpl == nil: + p.initErr = errors.New("portlist poller not implemented on " + runtime.GOOS) + default: + p.os = newOSImpl(p.IncludeLocalhost) + } +} + +// Close closes the Poller. +func (p *Poller) Close() error { + if p.initErr != nil { + return p.initErr + } + if p.os == nil { + return nil + } + return p.os.Close() +} + +// Poll returns the list of listening ports, if changed from +// a previous call as indicated by the changed result. +func (p *Poller) Poll() (ports []Port, changed bool, err error) { + p.initOnce.Do(p.init) + if p.initErr != nil { + return nil, false, fmt.Errorf("error initializing poller: %w", p.initErr) + } + pl, err := p.getList() + if err != nil { + return nil, false, err + } + if pl.equal(p.prev) { + return nil, false, nil + } + p.setPrev(pl) + return p.prev, true, nil +} + +func (p *Poller) getList() (List, error) { + var err error + p.scratch, err = p.os.AppendListeningPorts(p.scratch[:0]) + return p.scratch, err +} diff --git a/portlist/portlist.go b/portlist/portlist.go index 6d24cedcc..9f7af40d0 100644 --- a/portlist/portlist.go +++ b/portlist/portlist.go @@ -1,80 +1,80 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// This file is just the types. The bulk of the code is in poller.go. - -// The portlist package contains code that checks what ports are open and -// listening on the current machine. -package portlist - -import ( - "fmt" - "sort" - "strings" -) - -// Port is a listening port on the machine. -type Port struct { - Proto string // "tcp" or "udp" - Port uint16 // port number - Process string // optional process name, if found (requires suitable permissions) - Pid int // process ID, if known (requires suitable permissions) -} - -// List is a list of Ports. -type List []Port - -func (a *Port) lessThan(b *Port) bool { - if a.Port != b.Port { - return a.Port < b.Port - } - if a.Proto != b.Proto { - return a.Proto < b.Proto - } - return a.Process < b.Process -} - -func (a *Port) equal(b *Port) bool { - return a.Port == b.Port && - a.Proto == b.Proto && - a.Process == b.Process -} - -func (a List) equal(b List) bool { - if len(a) != len(b) { - return false - } - for i := range a { - if !a[i].equal(&b[i]) { - return false - } - } - return true -} - -func (pl List) String() string { - var sb strings.Builder - for _, v := range pl { - fmt.Fprintf(&sb, "%-3s %5d %#v\n", - v.Proto, v.Port, v.Process) - } - return strings.TrimRight(sb.String(), "\n") -} - -// sortAndDedup sorts ps in place (by Port.lessThan) and then returns -// a subset of it with duplicate (Proto, Port) removed. -func sortAndDedup(ps List) List { - sort.Slice(ps, func(i, j int) bool { - return (&ps[i]).lessThan(&ps[j]) - }) - out := ps[:0] - var last Port - for _, p := range ps { - if last.Proto == p.Proto && last.Port == p.Port { - continue - } - out = append(out, p) - last = p - } - return out -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// This file is just the types. The bulk of the code is in poller.go. + +// The portlist package contains code that checks what ports are open and +// listening on the current machine. +package portlist + +import ( + "fmt" + "sort" + "strings" +) + +// Port is a listening port on the machine. +type Port struct { + Proto string // "tcp" or "udp" + Port uint16 // port number + Process string // optional process name, if found (requires suitable permissions) + Pid int // process ID, if known (requires suitable permissions) +} + +// List is a list of Ports. +type List []Port + +func (a *Port) lessThan(b *Port) bool { + if a.Port != b.Port { + return a.Port < b.Port + } + if a.Proto != b.Proto { + return a.Proto < b.Proto + } + return a.Process < b.Process +} + +func (a *Port) equal(b *Port) bool { + return a.Port == b.Port && + a.Proto == b.Proto && + a.Process == b.Process +} + +func (a List) equal(b List) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if !a[i].equal(&b[i]) { + return false + } + } + return true +} + +func (pl List) String() string { + var sb strings.Builder + for _, v := range pl { + fmt.Fprintf(&sb, "%-3s %5d %#v\n", + v.Proto, v.Port, v.Process) + } + return strings.TrimRight(sb.String(), "\n") +} + +// sortAndDedup sorts ps in place (by Port.lessThan) and then returns +// a subset of it with duplicate (Proto, Port) removed. +func sortAndDedup(ps List) List { + sort.Slice(ps, func(i, j int) bool { + return (&ps[i]).lessThan(&ps[j]) + }) + out := ps[:0] + var last Port + for _, p := range ps { + if last.Proto == p.Proto && last.Port == p.Port { + continue + } + out = append(out, p) + last = p + } + return out +} diff --git a/portlist/portlist_macos.go b/portlist/portlist_macos.go index 2f4fee351..e67b2c9b8 100644 --- a/portlist/portlist_macos.go +++ b/portlist/portlist_macos.go @@ -1,230 +1,230 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build darwin && !ios - -package portlist - -import ( - "bufio" - "bytes" - "fmt" - "log" - "os/exec" - "strings" - "sync/atomic" - "time" - - "go4.org/mem" -) - -func init() { - newOSImpl = newMacOSImpl - - // We have to run netstat, which is a bit expensive, so don't do it too often. - pollInterval = 5 * time.Second -} - -type macOSImpl struct { - known map[protoPort]*portMeta // inode string => metadata - netstatPath string // lazily populated - - br *bufio.Reader // reused - portsBuf []Port - includeLocalhost bool -} - -type protoPort struct { - proto string - port uint16 -} - -type portMeta struct { - port Port - keep bool -} - -func newMacOSImpl(includeLocalhost bool) osImpl { - return &macOSImpl{ - known: map[protoPort]*portMeta{}, - br: bufio.NewReader(bytes.NewReader(nil)), - includeLocalhost: includeLocalhost, - } -} - -func (*macOSImpl) Close() error { return nil } - -func (im *macOSImpl) AppendListeningPorts(base []Port) ([]Port, error) { - var err error - im.portsBuf, err = im.appendListeningPortsNetstat(im.portsBuf[:0]) - if err != nil { - return nil, err - } - - for _, pm := range im.known { - pm.keep = false - } - - var needProcs bool - for _, p := range im.portsBuf { - fp := protoPort{ - proto: p.Proto, - port: p.Port, - } - if pm, ok := im.known[fp]; ok { - pm.keep = true - } else { - needProcs = true - im.known[fp] = &portMeta{ - port: p, - keep: true, - } - } - } - - ret := base - for k, m := range im.known { - if !m.keep { - delete(im.known, k) - } - } - - if needProcs { - im.addProcesses() // best effort - } - - for _, m := range im.known { - ret = append(ret, m.port) - } - return sortAndDedup(ret), nil -} - -func (im *macOSImpl) appendListeningPortsNetstat(base []Port) ([]Port, error) { - if im.netstatPath == "" { - var err error - im.netstatPath, err = exec.LookPath("netstat") - if err != nil { - return nil, fmt.Errorf("netstat: lookup: %v", err) - } - } - - cmd := exec.Command(im.netstatPath, "-na") - outPipe, err := cmd.StdoutPipe() - if err != nil { - return nil, err - } - im.br.Reset(outPipe) - - if err := cmd.Start(); err != nil { - return nil, err - } - defer cmd.Process.Wait() - defer cmd.Process.Kill() - - return appendParsePortsNetstat(base, im.br, im.includeLocalhost) -} - -var lsofFailed atomic.Bool - -// In theory, lsof could replace the function of both listPorts() and -// addProcesses(), since it provides a superset of the netstat output. -// However, "netstat -na" runs ~100x faster than lsof on my machine, so -// we should do it only if the list of open ports has actually changed. -// -// This fails in a macOS sandbox (i.e. in the Mac App Store or System -// Extension GUI build), but does at least work in the -// tailscaled-on-macos mode. -func (im *macOSImpl) addProcesses() error { - if lsofFailed.Load() { - // This previously failed in the macOS sandbox, so don't try again. - return nil - } - exe, err := exec.LookPath("lsof") - if err != nil { - return fmt.Errorf("lsof: lookup: %v", err) - } - lsofCmd := exec.Command(exe, "-F", "-n", "-P", "-O", "-S2", "-T", "-i4", "-i6") - outPipe, err := lsofCmd.StdoutPipe() - if err != nil { - return err - } - err = lsofCmd.Start() - if err != nil { - var stderr []byte - if xe, ok := err.(*exec.ExitError); ok { - stderr = xe.Stderr - } - // fails when run in a macOS sandbox, so make this non-fatal. - if lsofFailed.CompareAndSwap(false, true) { - log.Printf("portlist: can't run lsof in Mac sandbox; omitting process names from service list. Error details: %v, %s", err, bytes.TrimSpace(stderr)) - } - return nil - } - defer func() { - ps, err := lsofCmd.Process.Wait() - if err != nil || ps.ExitCode() != 0 { - log.Printf("portlist: can't run lsof in Mac sandbox; omitting process names from service list. Error: %v, exit code %d", err, ps.ExitCode()) - lsofFailed.Store(true) - } - }() - defer lsofCmd.Process.Kill() - - im.br.Reset(outPipe) - - var cmd, proto string - var pid int - for { - line, err := im.br.ReadBytes('\n') - if err != nil { - break - } - if len(line) < 1 { - continue - } - field, val := line[0], bytes.TrimSpace(line[1:]) - switch field { - case 'p': - // starting a new process - cmd = "" - proto = "" - pid = 0 - if p, err := mem.ParseInt(mem.B(val), 10, 0); err == nil { - pid = int(p) - } - case 'c': - cmd = string(val) // TODO(bradfitz): avoid garbage; cache process names between runs? - case 'P': - proto = lsofProtoLower(val) - case 'n': - if mem.Contains(mem.B(val), mem.S("->")) { - continue - } - // a listening port - port := parsePort(mem.B(val)) - if port <= 0 { - continue - } - pp := protoPort{proto, uint16(port)} - m := im.known[pp] - switch { - case m != nil: - m.port.Process = cmd - m.port.Pid = pid - default: - // ignore: processes and ports come and go - } - } - } - - return nil -} - -func lsofProtoLower(p []byte) string { - if string(p) == "TCP" { - return "tcp" - } - if string(p) == "UDP" { - return "udp" - } - return strings.ToLower(string(p)) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build darwin && !ios + +package portlist + +import ( + "bufio" + "bytes" + "fmt" + "log" + "os/exec" + "strings" + "sync/atomic" + "time" + + "go4.org/mem" +) + +func init() { + newOSImpl = newMacOSImpl + + // We have to run netstat, which is a bit expensive, so don't do it too often. + pollInterval = 5 * time.Second +} + +type macOSImpl struct { + known map[protoPort]*portMeta // inode string => metadata + netstatPath string // lazily populated + + br *bufio.Reader // reused + portsBuf []Port + includeLocalhost bool +} + +type protoPort struct { + proto string + port uint16 +} + +type portMeta struct { + port Port + keep bool +} + +func newMacOSImpl(includeLocalhost bool) osImpl { + return &macOSImpl{ + known: map[protoPort]*portMeta{}, + br: bufio.NewReader(bytes.NewReader(nil)), + includeLocalhost: includeLocalhost, + } +} + +func (*macOSImpl) Close() error { return nil } + +func (im *macOSImpl) AppendListeningPorts(base []Port) ([]Port, error) { + var err error + im.portsBuf, err = im.appendListeningPortsNetstat(im.portsBuf[:0]) + if err != nil { + return nil, err + } + + for _, pm := range im.known { + pm.keep = false + } + + var needProcs bool + for _, p := range im.portsBuf { + fp := protoPort{ + proto: p.Proto, + port: p.Port, + } + if pm, ok := im.known[fp]; ok { + pm.keep = true + } else { + needProcs = true + im.known[fp] = &portMeta{ + port: p, + keep: true, + } + } + } + + ret := base + for k, m := range im.known { + if !m.keep { + delete(im.known, k) + } + } + + if needProcs { + im.addProcesses() // best effort + } + + for _, m := range im.known { + ret = append(ret, m.port) + } + return sortAndDedup(ret), nil +} + +func (im *macOSImpl) appendListeningPortsNetstat(base []Port) ([]Port, error) { + if im.netstatPath == "" { + var err error + im.netstatPath, err = exec.LookPath("netstat") + if err != nil { + return nil, fmt.Errorf("netstat: lookup: %v", err) + } + } + + cmd := exec.Command(im.netstatPath, "-na") + outPipe, err := cmd.StdoutPipe() + if err != nil { + return nil, err + } + im.br.Reset(outPipe) + + if err := cmd.Start(); err != nil { + return nil, err + } + defer cmd.Process.Wait() + defer cmd.Process.Kill() + + return appendParsePortsNetstat(base, im.br, im.includeLocalhost) +} + +var lsofFailed atomic.Bool + +// In theory, lsof could replace the function of both listPorts() and +// addProcesses(), since it provides a superset of the netstat output. +// However, "netstat -na" runs ~100x faster than lsof on my machine, so +// we should do it only if the list of open ports has actually changed. +// +// This fails in a macOS sandbox (i.e. in the Mac App Store or System +// Extension GUI build), but does at least work in the +// tailscaled-on-macos mode. +func (im *macOSImpl) addProcesses() error { + if lsofFailed.Load() { + // This previously failed in the macOS sandbox, so don't try again. + return nil + } + exe, err := exec.LookPath("lsof") + if err != nil { + return fmt.Errorf("lsof: lookup: %v", err) + } + lsofCmd := exec.Command(exe, "-F", "-n", "-P", "-O", "-S2", "-T", "-i4", "-i6") + outPipe, err := lsofCmd.StdoutPipe() + if err != nil { + return err + } + err = lsofCmd.Start() + if err != nil { + var stderr []byte + if xe, ok := err.(*exec.ExitError); ok { + stderr = xe.Stderr + } + // fails when run in a macOS sandbox, so make this non-fatal. + if lsofFailed.CompareAndSwap(false, true) { + log.Printf("portlist: can't run lsof in Mac sandbox; omitting process names from service list. Error details: %v, %s", err, bytes.TrimSpace(stderr)) + } + return nil + } + defer func() { + ps, err := lsofCmd.Process.Wait() + if err != nil || ps.ExitCode() != 0 { + log.Printf("portlist: can't run lsof in Mac sandbox; omitting process names from service list. Error: %v, exit code %d", err, ps.ExitCode()) + lsofFailed.Store(true) + } + }() + defer lsofCmd.Process.Kill() + + im.br.Reset(outPipe) + + var cmd, proto string + var pid int + for { + line, err := im.br.ReadBytes('\n') + if err != nil { + break + } + if len(line) < 1 { + continue + } + field, val := line[0], bytes.TrimSpace(line[1:]) + switch field { + case 'p': + // starting a new process + cmd = "" + proto = "" + pid = 0 + if p, err := mem.ParseInt(mem.B(val), 10, 0); err == nil { + pid = int(p) + } + case 'c': + cmd = string(val) // TODO(bradfitz): avoid garbage; cache process names between runs? + case 'P': + proto = lsofProtoLower(val) + case 'n': + if mem.Contains(mem.B(val), mem.S("->")) { + continue + } + // a listening port + port := parsePort(mem.B(val)) + if port <= 0 { + continue + } + pp := protoPort{proto, uint16(port)} + m := im.known[pp] + switch { + case m != nil: + m.port.Process = cmd + m.port.Pid = pid + default: + // ignore: processes and ports come and go + } + } + } + + return nil +} + +func lsofProtoLower(p []byte) string { + if string(p) == "TCP" { + return "tcp" + } + if string(p) == "UDP" { + return "udp" + } + return strings.ToLower(string(p)) +} diff --git a/portlist/portlist_windows.go b/portlist/portlist_windows.go index c164dbad7..f44997359 100644 --- a/portlist/portlist_windows.go +++ b/portlist/portlist_windows.go @@ -1,103 +1,103 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package portlist - -import ( - "time" - - "tailscale.com/net/netstat" -) - -func init() { - newOSImpl = newWindowsImpl - // The portlist poller used to fork on Windows, which is insanely expensive, - // so historically we only did this every 5 seconds on Windows. Maybe we - // could reduce it down to 1 seconds like Linux, but nobody's benchmarked as - // of 2022-11-04. - pollInterval = 5 * time.Second -} - -type famPort struct { - proto string - port uint16 - pid uint32 -} - -type windowsImpl struct { - known map[famPort]*portMeta // inode string => metadata - includeLocalhost bool -} - -type portMeta struct { - port Port - keep bool -} - -func newWindowsImpl(includeLocalhost bool) osImpl { - return &windowsImpl{ - known: map[famPort]*portMeta{}, - includeLocalhost: includeLocalhost, - } -} - -func (*windowsImpl) Close() error { return nil } - -func (im *windowsImpl) AppendListeningPorts(base []Port) ([]Port, error) { - // TODO(bradfitz): netstat.Get makes a bunch of garbage. Add an Append-style - // API to that package instead/additionally. - tab, err := netstat.Get() - if err != nil { - return nil, err - } - - for _, pm := range im.known { - pm.keep = false - } - - ret := base - for _, e := range tab.Entries { - if e.State != "LISTEN" { - continue - } - if !im.includeLocalhost && !e.Local.Addr().IsUnspecified() { - continue - } - fp := famPort{ - proto: "tcp", // TODO(bradfitz): UDP too; add to netstat - port: e.Local.Port(), - pid: uint32(e.Pid), - } - pm, ok := im.known[fp] - if ok { - pm.keep = true - continue - } - var process string - if e.OSMetadata != nil { - if module, err := e.OSMetadata.GetModule(); err == nil { - process = module - } - } - pm = &portMeta{ - keep: true, - port: Port{ - Proto: "tcp", - Port: e.Local.Port(), - Process: process, - Pid: e.Pid, - }, - } - im.known[fp] = pm - } - - for k, m := range im.known { - if !m.keep { - delete(im.known, k) - continue - } - ret = append(ret, m.port) - } - - return sortAndDedup(ret), nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package portlist + +import ( + "time" + + "tailscale.com/net/netstat" +) + +func init() { + newOSImpl = newWindowsImpl + // The portlist poller used to fork on Windows, which is insanely expensive, + // so historically we only did this every 5 seconds on Windows. Maybe we + // could reduce it down to 1 seconds like Linux, but nobody's benchmarked as + // of 2022-11-04. + pollInterval = 5 * time.Second +} + +type famPort struct { + proto string + port uint16 + pid uint32 +} + +type windowsImpl struct { + known map[famPort]*portMeta // inode string => metadata + includeLocalhost bool +} + +type portMeta struct { + port Port + keep bool +} + +func newWindowsImpl(includeLocalhost bool) osImpl { + return &windowsImpl{ + known: map[famPort]*portMeta{}, + includeLocalhost: includeLocalhost, + } +} + +func (*windowsImpl) Close() error { return nil } + +func (im *windowsImpl) AppendListeningPorts(base []Port) ([]Port, error) { + // TODO(bradfitz): netstat.Get makes a bunch of garbage. Add an Append-style + // API to that package instead/additionally. + tab, err := netstat.Get() + if err != nil { + return nil, err + } + + for _, pm := range im.known { + pm.keep = false + } + + ret := base + for _, e := range tab.Entries { + if e.State != "LISTEN" { + continue + } + if !im.includeLocalhost && !e.Local.Addr().IsUnspecified() { + continue + } + fp := famPort{ + proto: "tcp", // TODO(bradfitz): UDP too; add to netstat + port: e.Local.Port(), + pid: uint32(e.Pid), + } + pm, ok := im.known[fp] + if ok { + pm.keep = true + continue + } + var process string + if e.OSMetadata != nil { + if module, err := e.OSMetadata.GetModule(); err == nil { + process = module + } + } + pm = &portMeta{ + keep: true, + port: Port{ + Proto: "tcp", + Port: e.Local.Port(), + Process: process, + Pid: e.Pid, + }, + } + im.known[fp] = pm + } + + for k, m := range im.known { + if !m.keep { + delete(im.known, k) + continue + } + ret = append(ret, m.port) + } + + return sortAndDedup(ret), nil +} diff --git a/posture/serialnumber_macos.go b/posture/serialnumber_macos.go index ce0b99683..48355d313 100644 --- a/posture/serialnumber_macos.go +++ b/posture/serialnumber_macos.go @@ -1,74 +1,74 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build cgo && darwin && !ios - -package posture - -// #cgo LDFLAGS: -framework CoreFoundation -framework IOKit -// #include -// #include -// -// #if __MAC_OS_X_VERSION_MIN_REQUIRED < 120000 -// #define kIOMainPortDefault kIOMasterPortDefault -// #endif -// -// const char * -// getSerialNumber() -// { -// CFMutableDictionaryRef matching = IOServiceMatching("IOPlatformExpertDevice"); -// if (!matching) { -// return "err: failed to create dictionary to match IOServices"; -// } -// -// io_service_t service = IOServiceGetMatchingService(kIOMainPortDefault, matching); -// if (!service) { -// return "err: failed to look up registered IOService objects that match a matching dictionary"; -// } -// -// CFStringRef serialNumberRef = IORegistryEntryCreateCFProperty( -// service, -// CFSTR("IOPlatformSerialNumber"), -// kCFAllocatorDefault, -// 0 -// ); -// if (!serialNumberRef) { -// return "err: failed to look up serial number in IORegistry"; -// } -// -// CFIndex length = CFStringGetLength(serialNumberRef); -// CFIndex max_size = CFStringGetMaximumSizeForEncoding(length, kCFStringEncodingUTF8) + 1; -// char *serialNumberBuf = (char *)malloc(max_size); -// -// bool result = CFStringGetCString(serialNumberRef, serialNumberBuf, max_size, kCFStringEncodingUTF8); -// -// CFRelease(serialNumberRef); -// IOObjectRelease(service); -// -// if (!result) { -// free(serialNumberBuf); -// -// return "err: failed to convert serial number reference to string"; -// } -// -// return serialNumberBuf; -// } -import "C" -import ( - "fmt" - "strings" - - "tailscale.com/types/logger" -) - -// GetSerialNumber returns the platform serial sumber as reported by IOKit. -func GetSerialNumbers(_ logger.Logf) ([]string, error) { - csn := C.getSerialNumber() - serialNumber := C.GoString(csn) - - if err, ok := strings.CutPrefix(serialNumber, "err: "); ok { - return nil, fmt.Errorf("failed to get serial number from IOKit: %s", err) - } - - return []string{serialNumber}, nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build cgo && darwin && !ios + +package posture + +// #cgo LDFLAGS: -framework CoreFoundation -framework IOKit +// #include +// #include +// +// #if __MAC_OS_X_VERSION_MIN_REQUIRED < 120000 +// #define kIOMainPortDefault kIOMasterPortDefault +// #endif +// +// const char * +// getSerialNumber() +// { +// CFMutableDictionaryRef matching = IOServiceMatching("IOPlatformExpertDevice"); +// if (!matching) { +// return "err: failed to create dictionary to match IOServices"; +// } +// +// io_service_t service = IOServiceGetMatchingService(kIOMainPortDefault, matching); +// if (!service) { +// return "err: failed to look up registered IOService objects that match a matching dictionary"; +// } +// +// CFStringRef serialNumberRef = IORegistryEntryCreateCFProperty( +// service, +// CFSTR("IOPlatformSerialNumber"), +// kCFAllocatorDefault, +// 0 +// ); +// if (!serialNumberRef) { +// return "err: failed to look up serial number in IORegistry"; +// } +// +// CFIndex length = CFStringGetLength(serialNumberRef); +// CFIndex max_size = CFStringGetMaximumSizeForEncoding(length, kCFStringEncodingUTF8) + 1; +// char *serialNumberBuf = (char *)malloc(max_size); +// +// bool result = CFStringGetCString(serialNumberRef, serialNumberBuf, max_size, kCFStringEncodingUTF8); +// +// CFRelease(serialNumberRef); +// IOObjectRelease(service); +// +// if (!result) { +// free(serialNumberBuf); +// +// return "err: failed to convert serial number reference to string"; +// } +// +// return serialNumberBuf; +// } +import "C" +import ( + "fmt" + "strings" + + "tailscale.com/types/logger" +) + +// GetSerialNumber returns the platform serial sumber as reported by IOKit. +func GetSerialNumbers(_ logger.Logf) ([]string, error) { + csn := C.getSerialNumber() + serialNumber := C.GoString(csn) + + if err, ok := strings.CutPrefix(serialNumber, "err: "); ok { + return nil, fmt.Errorf("failed to get serial number from IOKit: %s", err) + } + + return []string{serialNumber}, nil +} diff --git a/posture/serialnumber_notmacos_test.go b/posture/serialnumber_notmacos_test.go index 8106c34b3..f2a15e037 100644 --- a/posture/serialnumber_notmacos_test.go +++ b/posture/serialnumber_notmacos_test.go @@ -1,38 +1,38 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Build on Windows, Linux and *BSD - -//go:build windows || (linux && !android) || freebsd || openbsd || dragonfly || netbsd - -package posture - -import ( - "fmt" - "testing" - - "tailscale.com/types/logger" -) - -func TestGetSerialNumberNotMac(t *testing.T) { - // This test is intentionally skipped as it will - // require root on Linux to get access to the serials. - // The test case is intended for local testing. - // Comment out skip for local testing. - t.Skip() - - sns, err := GetSerialNumbers(logger.Discard) - if err != nil { - t.Fatalf("failed to get serial number: %s", err) - } - - if len(sns) == 0 { - t.Fatalf("expected at least one serial number, got %v", sns) - } - - if len(sns[0]) <= 0 { - t.Errorf("expected a serial number with more than zero characters, got %s", sns[0]) - } - - fmt.Printf("serials: %v\n", sns) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Build on Windows, Linux and *BSD + +//go:build windows || (linux && !android) || freebsd || openbsd || dragonfly || netbsd + +package posture + +import ( + "fmt" + "testing" + + "tailscale.com/types/logger" +) + +func TestGetSerialNumberNotMac(t *testing.T) { + // This test is intentionally skipped as it will + // require root on Linux to get access to the serials. + // The test case is intended for local testing. + // Comment out skip for local testing. + t.Skip() + + sns, err := GetSerialNumbers(logger.Discard) + if err != nil { + t.Fatalf("failed to get serial number: %s", err) + } + + if len(sns) == 0 { + t.Fatalf("expected at least one serial number, got %v", sns) + } + + if len(sns[0]) <= 0 { + t.Errorf("expected a serial number with more than zero characters, got %s", sns[0]) + } + + fmt.Printf("serials: %v\n", sns) +} diff --git a/posture/serialnumber_test.go b/posture/serialnumber_test.go index 1ab819336..fac4392fa 100644 --- a/posture/serialnumber_test.go +++ b/posture/serialnumber_test.go @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package posture - -import ( - "testing" - - "tailscale.com/types/logger" -) - -func TestGetSerialNumber(t *testing.T) { - // ensure GetSerialNumbers is implemented - // or covered by a stub on a given platform. - _, _ = GetSerialNumbers(logger.Discard) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package posture + +import ( + "testing" + + "tailscale.com/types/logger" +) + +func TestGetSerialNumber(t *testing.T) { + // ensure GetSerialNumbers is implemented + // or covered by a stub on a given platform. + _, _ = GetSerialNumbers(logger.Discard) +} diff --git a/pull-toolchain.sh b/pull-toolchain.sh index 87350ff53..f5a19e7d7 100755 --- a/pull-toolchain.sh +++ b/pull-toolchain.sh @@ -1,16 +1,16 @@ -#!/bin/sh -# Retrieve the latest Go toolchain. -# -set -eu -cd "$(dirname "$0")" - -read -r go_branch go.toolchain.rev -fi - -if [ -n "$(git diff-index --name-only HEAD -- go.toolchain.rev)" ]; then - echo "pull-toolchain.sh: changes imported. Use git commit to make them permanent." >&2 -fi +#!/bin/sh +# Retrieve the latest Go toolchain. +# +set -eu +cd "$(dirname "$0")" + +read -r go_branch go.toolchain.rev +fi + +if [ -n "$(git diff-index --name-only HEAD -- go.toolchain.rev)" ]; then + echo "pull-toolchain.sh: changes imported. Use git commit to make them permanent." >&2 +fi diff --git a/release/deb/debian.postrm.sh b/release/deb/debian.postrm.sh index 93d90b0ea..f4dd4ed9c 100755 --- a/release/deb/debian.postrm.sh +++ b/release/deb/debian.postrm.sh @@ -1,17 +1,17 @@ -#!/bin/sh -set -e -if [ -d /run/systemd/system ] ; then - systemctl --system daemon-reload >/dev/null || true -fi - -if [ -x "/usr/bin/deb-systemd-helper" ]; then - if [ "$1" = "remove" ]; then - deb-systemd-helper mask 'tailscaled.service' >/dev/null || true - fi - - if [ "$1" = "purge" ]; then - deb-systemd-helper purge 'tailscaled.service' >/dev/null || true - deb-systemd-helper unmask 'tailscaled.service' >/dev/null || true - rm -rf /var/lib/tailscale - fi -fi +#!/bin/sh +set -e +if [ -d /run/systemd/system ] ; then + systemctl --system daemon-reload >/dev/null || true +fi + +if [ -x "/usr/bin/deb-systemd-helper" ]; then + if [ "$1" = "remove" ]; then + deb-systemd-helper mask 'tailscaled.service' >/dev/null || true + fi + + if [ "$1" = "purge" ]; then + deb-systemd-helper purge 'tailscaled.service' >/dev/null || true + deb-systemd-helper unmask 'tailscaled.service' >/dev/null || true + rm -rf /var/lib/tailscale + fi +fi diff --git a/release/deb/debian.prerm.sh b/release/deb/debian.prerm.sh index a712a08c8..9be58ede4 100755 --- a/release/deb/debian.prerm.sh +++ b/release/deb/debian.prerm.sh @@ -1,7 +1,7 @@ -#!/bin/sh -set -e -if [ "$1" = "remove" ]; then - if [ -d /run/systemd/system ]; then - deb-systemd-invoke stop 'tailscaled.service' >/dev/null || true - fi -fi +#!/bin/sh +set -e +if [ "$1" = "remove" ]; then + if [ -d /run/systemd/system ]; then + deb-systemd-invoke stop 'tailscaled.service' >/dev/null || true + fi +fi diff --git a/release/dist/memoize.go b/release/dist/memoize.go index f148cd2b7..0927ac0a8 100644 --- a/release/dist/memoize.go +++ b/release/dist/memoize.go @@ -1,86 +1,86 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package dist - -import ( - "sync" - - "tailscale.com/util/deephash" -) - -// MemoizedFn is a function that memoize.Do can call. -type MemoizedFn[T any] func() (T, error) - -// Memoize runs MemoizedFns and remembers their results. -type Memoize[O any] struct { - mu sync.Mutex - cond *sync.Cond - outs map[deephash.Sum]O - errs map[deephash.Sum]error - inflight map[deephash.Sum]bool -} - -// Do runs fn and returns its result. -// fn is only run once per unique key. Subsequent Do calls with the same key -// return the memoized result of the first call, even if fn is a different -// function. -func (m *Memoize[O]) Do(key any, fn MemoizedFn[O]) (ret O, err error) { - m.mu.Lock() - defer m.mu.Unlock() - if m.cond == nil { - m.cond = sync.NewCond(&m.mu) - m.outs = map[deephash.Sum]O{} - m.errs = map[deephash.Sum]error{} - m.inflight = map[deephash.Sum]bool{} - } - - k := deephash.Hash(&key) - - for m.inflight[k] { - m.cond.Wait() - } - if err := m.errs[k]; err != nil { - var ret O - return ret, err - } - if ret, ok := m.outs[k]; ok { - return ret, nil - } - - m.inflight[k] = true - m.mu.Unlock() - defer func() { - m.mu.Lock() - delete(m.inflight, k) - if err != nil { - m.errs[k] = err - } else { - m.outs[k] = ret - } - m.cond.Broadcast() - }() - - ret, err = fn() - if err != nil { - var ret O - return ret, err - } - return ret, nil -} - -// once is like memoize, but for functions that don't return non-error values. -type once struct { - m Memoize[any] -} - -// Do runs fn. -// fn is only run once per unique key. Subsequent Do calls with the same key -// return the memoized result of the first call, even if fn is a different -// function. -func (o *once) Do(key any, fn func() error) error { - _, err := o.m.Do(key, func() (any, error) { - return nil, fn() - }) - return err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package dist + +import ( + "sync" + + "tailscale.com/util/deephash" +) + +// MemoizedFn is a function that memoize.Do can call. +type MemoizedFn[T any] func() (T, error) + +// Memoize runs MemoizedFns and remembers their results. +type Memoize[O any] struct { + mu sync.Mutex + cond *sync.Cond + outs map[deephash.Sum]O + errs map[deephash.Sum]error + inflight map[deephash.Sum]bool +} + +// Do runs fn and returns its result. +// fn is only run once per unique key. Subsequent Do calls with the same key +// return the memoized result of the first call, even if fn is a different +// function. +func (m *Memoize[O]) Do(key any, fn MemoizedFn[O]) (ret O, err error) { + m.mu.Lock() + defer m.mu.Unlock() + if m.cond == nil { + m.cond = sync.NewCond(&m.mu) + m.outs = map[deephash.Sum]O{} + m.errs = map[deephash.Sum]error{} + m.inflight = map[deephash.Sum]bool{} + } + + k := deephash.Hash(&key) + + for m.inflight[k] { + m.cond.Wait() + } + if err := m.errs[k]; err != nil { + var ret O + return ret, err + } + if ret, ok := m.outs[k]; ok { + return ret, nil + } + + m.inflight[k] = true + m.mu.Unlock() + defer func() { + m.mu.Lock() + delete(m.inflight, k) + if err != nil { + m.errs[k] = err + } else { + m.outs[k] = ret + } + m.cond.Broadcast() + }() + + ret, err = fn() + if err != nil { + var ret O + return ret, err + } + return ret, nil +} + +// once is like memoize, but for functions that don't return non-error values. +type once struct { + m Memoize[any] +} + +// Do runs fn. +// fn is only run once per unique key. Subsequent Do calls with the same key +// return the memoized result of the first call, even if fn is a different +// function. +func (o *once) Do(key any, fn func() error) error { + _, err := o.m.Do(key, func() (any, error) { + return nil, fn() + }) + return err +} diff --git a/release/dist/synology/files/Tailscale.sc b/release/dist/synology/files/Tailscale.sc index f3bb1f0bd..707ac6bb0 100644 --- a/release/dist/synology/files/Tailscale.sc +++ b/release/dist/synology/files/Tailscale.sc @@ -1,6 +1,6 @@ -[Tailscale] -title="Tailscale" -desc="Tailscale VPN" -port_forward="no" -src.ports="41641/udp" +[Tailscale] +title="Tailscale" +desc="Tailscale VPN" +port_forward="no" +src.ports="41641/udp" dst.ports="41641/udp" \ No newline at end of file diff --git a/release/dist/synology/files/config b/release/dist/synology/files/config index 1cf1a6cfa..4dbc48dfb 100644 --- a/release/dist/synology/files/config +++ b/release/dist/synology/files/config @@ -1,11 +1,11 @@ -{ - ".url": { - "SYNO.SDS.Tailscale": { - "type": "url", - "title": "Tailscale", - "icon": "PACKAGE_ICON_256.PNG", - "url": "webman/3rdparty/Tailscale/index.cgi/", - "urlTarget": "_syno_tailscale" - } - } -} +{ + ".url": { + "SYNO.SDS.Tailscale": { + "type": "url", + "title": "Tailscale", + "icon": "PACKAGE_ICON_256.PNG", + "url": "webman/3rdparty/Tailscale/index.cgi/", + "urlTarget": "_syno_tailscale" + } + } +} diff --git a/release/dist/synology/files/index.cgi b/release/dist/synology/files/index.cgi index 996160d1d..2c1990cfd 100755 --- a/release/dist/synology/files/index.cgi +++ b/release/dist/synology/files/index.cgi @@ -1,2 +1,2 @@ -#! /bin/sh -exec /var/packages/Tailscale/target/bin/tailscale web -cgi -prefix="/webman/3rdparty/Tailscale/index.cgi/" +#! /bin/sh +exec /var/packages/Tailscale/target/bin/tailscale web -cgi -prefix="/webman/3rdparty/Tailscale/index.cgi/" diff --git a/release/dist/synology/files/logrotate-dsm6 b/release/dist/synology/files/logrotate-dsm6 index a52a6ba24..2df64283a 100644 --- a/release/dist/synology/files/logrotate-dsm6 +++ b/release/dist/synology/files/logrotate-dsm6 @@ -1,8 +1,8 @@ -/var/packages/Tailscale/etc/tailscaled.stdout.log { - size 10M - rotate 3 - missingok - copytruncate - compress - notifempty -} +/var/packages/Tailscale/etc/tailscaled.stdout.log { + size 10M + rotate 3 + missingok + copytruncate + compress + notifempty +} diff --git a/release/dist/synology/files/logrotate-dsm7 b/release/dist/synology/files/logrotate-dsm7 index 3fe677510..7020dc925 100644 --- a/release/dist/synology/files/logrotate-dsm7 +++ b/release/dist/synology/files/logrotate-dsm7 @@ -1,8 +1,8 @@ -/var/packages/Tailscale/var/tailscaled.stdout.log { - size 10M - rotate 3 - missingok - copytruncate - compress - notifempty -} +/var/packages/Tailscale/var/tailscaled.stdout.log { + size 10M + rotate 3 + missingok + copytruncate + compress + notifempty +} diff --git a/release/dist/synology/files/privilege-dsm6 b/release/dist/synology/files/privilege-dsm6 index c638528d1..4b6fe093a 100644 --- a/release/dist/synology/files/privilege-dsm6 +++ b/release/dist/synology/files/privilege-dsm6 @@ -1,7 +1,7 @@ -{ - "defaults":{ - "run-as": "root" - }, - "username": "tailscale", - "groupname": "tailscale" -} +{ + "defaults":{ + "run-as": "root" + }, + "username": "tailscale", + "groupname": "tailscale" +} diff --git a/release/dist/synology/files/privilege-dsm7 b/release/dist/synology/files/privilege-dsm7 index 4eca66cff..93a9c4f7d 100644 --- a/release/dist/synology/files/privilege-dsm7 +++ b/release/dist/synology/files/privilege-dsm7 @@ -1,7 +1,7 @@ -{ - "defaults":{ - "run-as": "package" - }, - "username": "tailscale", - "groupname": "tailscale" -} +{ + "defaults":{ + "run-as": "package" + }, + "username": "tailscale", + "groupname": "tailscale" +} diff --git a/release/dist/synology/files/privilege-dsm7.for-package-center b/release/dist/synology/files/privilege-dsm7.for-package-center index b2f93cee1..db1468346 100644 --- a/release/dist/synology/files/privilege-dsm7.for-package-center +++ b/release/dist/synology/files/privilege-dsm7.for-package-center @@ -1,13 +1,13 @@ -{ - "defaults":{ - "run-as": "package" - }, - "username": "tailscale", - "groupname": "tailscale", - "tool": [{ - "relpath": "bin/tailscaled", - "user": "package", - "group": "package", - "capabilities": "cap_net_admin,cap_chown,cap_net_raw" - }] -} +{ + "defaults":{ + "run-as": "package" + }, + "username": "tailscale", + "groupname": "tailscale", + "tool": [{ + "relpath": "bin/tailscaled", + "user": "package", + "group": "package", + "capabilities": "cap_net_admin,cap_chown,cap_net_raw" + }] +} diff --git a/release/dist/synology/files/resource b/release/dist/synology/files/resource index 706c97671..0da0002ef 100644 --- a/release/dist/synology/files/resource +++ b/release/dist/synology/files/resource @@ -1,11 +1,11 @@ -{ - "port-config": { - "protocol-file": "conf/Tailscale.sc" - }, - "usr-local-linker": { - "bin": ["bin/tailscale"] - }, - "syslog-config": { - "logrotate-relpath": "conf/logrotate.conf" - } +{ + "port-config": { + "protocol-file": "conf/Tailscale.sc" + }, + "usr-local-linker": { + "bin": ["bin/tailscale"] + }, + "syslog-config": { + "logrotate-relpath": "conf/logrotate.conf" + } } \ No newline at end of file diff --git a/release/dist/synology/files/scripts/postupgrade b/release/dist/synology/files/scripts/postupgrade index 2a7fba5b6..92b94c40c 100644 --- a/release/dist/synology/files/scripts/postupgrade +++ b/release/dist/synology/files/scripts/postupgrade @@ -1,3 +1,3 @@ -#!/bin/sh - +#!/bin/sh + exit 0 \ No newline at end of file diff --git a/release/dist/synology/files/scripts/preupgrade b/release/dist/synology/files/scripts/preupgrade index 2a7fba5b6..92b94c40c 100644 --- a/release/dist/synology/files/scripts/preupgrade +++ b/release/dist/synology/files/scripts/preupgrade @@ -1,3 +1,3 @@ -#!/bin/sh - +#!/bin/sh + exit 0 \ No newline at end of file diff --git a/release/dist/synology/files/scripts/start-stop-status b/release/dist/synology/files/scripts/start-stop-status index 311f9293b..e6ece04e3 100755 --- a/release/dist/synology/files/scripts/start-stop-status +++ b/release/dist/synology/files/scripts/start-stop-status @@ -1,129 +1,129 @@ -#!/bin/bash - -SERVICE_NAME="tailscale" - -if [ "${SYNOPKG_DSM_VERSION_MAJOR}" -eq "6" ]; then - PKGVAR="/var/packages/Tailscale/etc" -else - PKGVAR="${SYNOPKG_PKGVAR}" -fi - -PID_FILE="${PKGVAR}/tailscaled.pid" -LOG_FILE="${PKGVAR}/tailscaled.stdout.log" -STATE_FILE="${PKGVAR}/tailscaled.state" -SOCKET_FILE="${PKGVAR}/tailscaled.sock" -PORT="41641" - -SERVICE_COMMAND="${SYNOPKG_PKGDEST}/bin/tailscaled \ ---state=${STATE_FILE} \ ---socket=${SOCKET_FILE} \ ---port=$PORT" - -if [ "${SYNOPKG_DSM_VERSION_MAJOR}" -eq "7" -a ! -e "/dev/net/tun" ]; then - # TODO(maisem/crawshaw): Disable the tun device in DSM7 for now. - SERVICE_COMMAND="${SERVICE_COMMAND} --tun=userspace-networking" -fi - -if [ "${SYNOPKG_DSM_VERSION_MAJOR}" -eq "6" ]; then - chown -R tailscale:tailscale "${PKGVAR}/" -fi - -start_daemon() { - local ts=$(date --iso-8601=second) - echo "${ts} Starting ${SERVICE_NAME} with: ${SERVICE_COMMAND}" >${LOG_FILE} - STATE_DIRECTORY=${PKGVAR} ${SERVICE_COMMAND} 2>&1 | sed -u '1,200p;201s,.*,[further tailscaled logs suppressed],p;d' >>${LOG_FILE} & - # We pipe tailscaled's output to sed, so "$!" retrieves the PID of sed not tailscaled. - # Use jobs -p to retrieve the PID of the most recent process group leader. - jobs -p >"${PID_FILE}" -} - -stop_daemon() { - if [ -r "${PID_FILE}" ]; then - local PID=$(cat "${PID_FILE}") - local ts=$(date --iso-8601=second) - echo "${ts} Stopping ${SERVICE_NAME} service PID=${PID}" >>${LOG_FILE} - kill -TERM $PID >>${LOG_FILE} 2>&1 - wait_for_status 1 || kill -KILL $PID >>${LOG_FILE} 2>&1 - rm -f "${PID_FILE}" >/dev/null - fi -} - -daemon_status() { - if [ -r "${PID_FILE}" ]; then - local PID=$(cat "${PID_FILE}") - if ps -o pid -p ${PID} > /dev/null; then - return - fi - rm -f "${PID_FILE}" >/dev/null - fi - return 1 -} - -wait_for_status() { - # 20 tries - # sleeps for 1 second after each try - local counter=20 - while [ ${counter} -gt 0 ]; do - daemon_status - [ $? -eq $1 ] && return - counter=$((counter - 1)) - sleep 1 - done - return 1 -} - -ensure_tun_created() { - if [ "${SYNOPKG_DSM_VERSION_MAJOR}" -eq "7" ]; then - # TODO(maisem/crawshaw): Disable the tun device in DSM7 for now. - return - fi - # Create the necessary file structure for /dev/net/tun - if ([ ! -c /dev/net/tun ]); then - if ([ ! -d /dev/net ]); then - mkdir -m 755 /dev/net - fi - mknod /dev/net/tun c 10 200 - chmod 0755 /dev/net/tun - fi - - # Load the tun module if not already loaded - if (!(lsmod | grep -q "^tun\s")); then - insmod /lib/modules/tun.ko - fi -} - -case $1 in -start) - if daemon_status; then - exit 0 - else - ensure_tun_created - start_daemon - exit $? - fi - ;; -stop) - if daemon_status; then - stop_daemon - exit $? - else - exit 0 - fi - ;; -status) - if daemon_status; then - echo "${SERVICE_NAME} is running" - exit 0 - else - echo "${SERVICE_NAME} is not running" - exit 3 - fi - ;; -log) - exit 0 - ;; -*) - echo "command $1 is not implemented" - exit 0 - ;; -esac +#!/bin/bash + +SERVICE_NAME="tailscale" + +if [ "${SYNOPKG_DSM_VERSION_MAJOR}" -eq "6" ]; then + PKGVAR="/var/packages/Tailscale/etc" +else + PKGVAR="${SYNOPKG_PKGVAR}" +fi + +PID_FILE="${PKGVAR}/tailscaled.pid" +LOG_FILE="${PKGVAR}/tailscaled.stdout.log" +STATE_FILE="${PKGVAR}/tailscaled.state" +SOCKET_FILE="${PKGVAR}/tailscaled.sock" +PORT="41641" + +SERVICE_COMMAND="${SYNOPKG_PKGDEST}/bin/tailscaled \ +--state=${STATE_FILE} \ +--socket=${SOCKET_FILE} \ +--port=$PORT" + +if [ "${SYNOPKG_DSM_VERSION_MAJOR}" -eq "7" -a ! -e "/dev/net/tun" ]; then + # TODO(maisem/crawshaw): Disable the tun device in DSM7 for now. + SERVICE_COMMAND="${SERVICE_COMMAND} --tun=userspace-networking" +fi + +if [ "${SYNOPKG_DSM_VERSION_MAJOR}" -eq "6" ]; then + chown -R tailscale:tailscale "${PKGVAR}/" +fi + +start_daemon() { + local ts=$(date --iso-8601=second) + echo "${ts} Starting ${SERVICE_NAME} with: ${SERVICE_COMMAND}" >${LOG_FILE} + STATE_DIRECTORY=${PKGVAR} ${SERVICE_COMMAND} 2>&1 | sed -u '1,200p;201s,.*,[further tailscaled logs suppressed],p;d' >>${LOG_FILE} & + # We pipe tailscaled's output to sed, so "$!" retrieves the PID of sed not tailscaled. + # Use jobs -p to retrieve the PID of the most recent process group leader. + jobs -p >"${PID_FILE}" +} + +stop_daemon() { + if [ -r "${PID_FILE}" ]; then + local PID=$(cat "${PID_FILE}") + local ts=$(date --iso-8601=second) + echo "${ts} Stopping ${SERVICE_NAME} service PID=${PID}" >>${LOG_FILE} + kill -TERM $PID >>${LOG_FILE} 2>&1 + wait_for_status 1 || kill -KILL $PID >>${LOG_FILE} 2>&1 + rm -f "${PID_FILE}" >/dev/null + fi +} + +daemon_status() { + if [ -r "${PID_FILE}" ]; then + local PID=$(cat "${PID_FILE}") + if ps -o pid -p ${PID} > /dev/null; then + return + fi + rm -f "${PID_FILE}" >/dev/null + fi + return 1 +} + +wait_for_status() { + # 20 tries + # sleeps for 1 second after each try + local counter=20 + while [ ${counter} -gt 0 ]; do + daemon_status + [ $? -eq $1 ] && return + counter=$((counter - 1)) + sleep 1 + done + return 1 +} + +ensure_tun_created() { + if [ "${SYNOPKG_DSM_VERSION_MAJOR}" -eq "7" ]; then + # TODO(maisem/crawshaw): Disable the tun device in DSM7 for now. + return + fi + # Create the necessary file structure for /dev/net/tun + if ([ ! -c /dev/net/tun ]); then + if ([ ! -d /dev/net ]); then + mkdir -m 755 /dev/net + fi + mknod /dev/net/tun c 10 200 + chmod 0755 /dev/net/tun + fi + + # Load the tun module if not already loaded + if (!(lsmod | grep -q "^tun\s")); then + insmod /lib/modules/tun.ko + fi +} + +case $1 in +start) + if daemon_status; then + exit 0 + else + ensure_tun_created + start_daemon + exit $? + fi + ;; +stop) + if daemon_status; then + stop_daemon + exit $? + else + exit 0 + fi + ;; +status) + if daemon_status; then + echo "${SERVICE_NAME} is running" + exit 0 + else + echo "${SERVICE_NAME} is not running" + exit 3 + fi + ;; +log) + exit 0 + ;; +*) + echo "command $1 is not implemented" + exit 0 + ;; +esac diff --git a/release/dist/unixpkgs/pkgs.go b/release/dist/unixpkgs/pkgs.go index 60a038eb4..bad6ce572 100644 --- a/release/dist/unixpkgs/pkgs.go +++ b/release/dist/unixpkgs/pkgs.go @@ -1,472 +1,472 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package unixpkgs contains dist Targets for building unix Tailscale packages. -package unixpkgs - -import ( - "archive/tar" - "compress/gzip" - "errors" - "fmt" - "io" - "log" - "os" - "path/filepath" - "strings" - - "github.com/goreleaser/nfpm/v2" - "github.com/goreleaser/nfpm/v2/files" - "tailscale.com/release/dist" -) - -type tgzTarget struct { - filenameArch string // arch to use in filename instead of deriving from goEnv["GOARCH"] - goEnv map[string]string - signer dist.Signer -} - -func (t *tgzTarget) arch() string { - if t.filenameArch != "" { - return t.filenameArch - } - return t.goEnv["GOARCH"] -} - -func (t *tgzTarget) os() string { - return t.goEnv["GOOS"] -} - -func (t *tgzTarget) String() string { - return fmt.Sprintf("%s/%s/tgz", t.os(), t.arch()) -} - -func (t *tgzTarget) Build(b *dist.Build) ([]string, error) { - var filename string - if t.goEnv["GOOS"] == "linux" { - // Linux used to be the only tgz architecture, so we didn't put the OS - // name in the filename. - filename = fmt.Sprintf("tailscale_%s_%s.tgz", b.Version.Short, t.arch()) - } else { - filename = fmt.Sprintf("tailscale_%s_%s_%s.tgz", b.Version.Short, t.os(), t.arch()) - } - if err := b.BuildWebClientAssets(); err != nil { - return nil, err - } - ts, err := b.BuildGoBinary("tailscale.com/cmd/tailscale", t.goEnv) - if err != nil { - return nil, err - } - tsd, err := b.BuildGoBinary("tailscale.com/cmd/tailscaled", t.goEnv) - if err != nil { - return nil, err - } - - log.Printf("Building %s", filename) - - out := filepath.Join(b.Out, filename) - f, err := os.Create(out) - if err != nil { - return nil, err - } - defer f.Close() - gw := gzip.NewWriter(f) - defer gw.Close() - tw := tar.NewWriter(gw) - defer tw.Close() - - addFile := func(src, dst string, mode int64) error { - f, err := os.Open(src) - if err != nil { - return err - } - defer f.Close() - fi, err := f.Stat() - if err != nil { - return err - } - hdr := &tar.Header{ - Name: dst, - Size: fi.Size(), - Mode: mode, - ModTime: b.Time, - Uid: 0, - Gid: 0, - Uname: "root", - Gname: "root", - } - if err := tw.WriteHeader(hdr); err != nil { - return err - } - if _, err = io.Copy(tw, f); err != nil { - return err - } - return nil - } - addDir := func(name string) error { - hdr := &tar.Header{ - Name: name + "/", - Mode: 0755, - ModTime: b.Time, - Uid: 0, - Gid: 0, - Uname: "root", - Gname: "root", - } - return tw.WriteHeader(hdr) - } - dir := strings.TrimSuffix(filename, ".tgz") - if err := addDir(dir); err != nil { - return nil, err - } - if err := addFile(tsd, filepath.Join(dir, "tailscaled"), 0755); err != nil { - return nil, err - } - if err := addFile(ts, filepath.Join(dir, "tailscale"), 0755); err != nil { - return nil, err - } - if t.os() == "linux" { - dir = filepath.Join(dir, "systemd") - if err := addDir(dir); err != nil { - return nil, err - } - tailscaledDir, err := b.GoPkg("tailscale.com/cmd/tailscaled") - if err != nil { - return nil, err - } - if err := addFile(filepath.Join(tailscaledDir, "tailscaled.service"), filepath.Join(dir, "tailscaled.service"), 0644); err != nil { - return nil, err - } - if err := addFile(filepath.Join(tailscaledDir, "tailscaled.defaults"), filepath.Join(dir, "tailscaled.defaults"), 0644); err != nil { - return nil, err - } - } - if err := tw.Close(); err != nil { - return nil, err - } - if err := gw.Close(); err != nil { - return nil, err - } - if err := f.Close(); err != nil { - return nil, err - } - - files := []string{filename} - - if t.signer != nil { - outSig := out + ".sig" - if err := t.signer.SignFile(out, outSig); err != nil { - return nil, err - } - files = append(files, filepath.Base(outSig)) - } - - return files, nil -} - -type debTarget struct { - goEnv map[string]string -} - -func (t *debTarget) os() string { - return t.goEnv["GOOS"] -} - -func (t *debTarget) arch() string { - return t.goEnv["GOARCH"] -} - -func (t *debTarget) String() string { - return fmt.Sprintf("linux/%s/deb", t.goEnv["GOARCH"]) -} - -func (t *debTarget) Build(b *dist.Build) ([]string, error) { - if t.os() != "linux" { - return nil, errors.New("deb only supported on linux") - } - - if err := b.BuildWebClientAssets(); err != nil { - return nil, err - } - ts, err := b.BuildGoBinary("tailscale.com/cmd/tailscale", t.goEnv) - if err != nil { - return nil, err - } - tsd, err := b.BuildGoBinary("tailscale.com/cmd/tailscaled", t.goEnv) - if err != nil { - return nil, err - } - - tailscaledDir, err := b.GoPkg("tailscale.com/cmd/tailscaled") - if err != nil { - return nil, err - } - repoDir, err := b.GoPkg("tailscale.com") - if err != nil { - return nil, err - } - - arch := debArch(t.arch()) - contents, err := files.PrepareForPackager(files.Contents{ - &files.Content{ - Type: files.TypeFile, - Source: ts, - Destination: "/usr/bin/tailscale", - }, - &files.Content{ - Type: files.TypeFile, - Source: tsd, - Destination: "/usr/sbin/tailscaled", - }, - &files.Content{ - Type: files.TypeFile, - Source: filepath.Join(tailscaledDir, "tailscaled.service"), - Destination: "/lib/systemd/system/tailscaled.service", - }, - &files.Content{ - Type: files.TypeConfigNoReplace, - Source: filepath.Join(tailscaledDir, "tailscaled.defaults"), - Destination: "/etc/default/tailscaled", - }, - }, 0, "deb", false) - if err != nil { - return nil, err - } - info := nfpm.WithDefaults(&nfpm.Info{ - Name: "tailscale", - Arch: arch, - Platform: "linux", - Version: b.Version.Short, - Maintainer: "Tailscale Inc ", - Description: "The easiest, most secure, cross platform way to use WireGuard + oauth2 + 2FA/SSO", - Homepage: "https://www.tailscale.com", - License: "MIT", - Section: "net", - Priority: "extra", - Overridables: nfpm.Overridables{ - Contents: contents, - Scripts: nfpm.Scripts{ - PostInstall: filepath.Join(repoDir, "release/deb/debian.postinst.sh"), - PreRemove: filepath.Join(repoDir, "release/deb/debian.prerm.sh"), - PostRemove: filepath.Join(repoDir, "release/deb/debian.postrm.sh"), - }, - Depends: []string{ - // iptables is almost always required but not strictly needed. - // Even if you can technically run Tailscale without it (by - // manually configuring nftables or userspace mode), we still - // mark this as "Depends" because our previous experiment in - // https://github.com/tailscale/tailscale/issues/9236 of making - // it only Recommends caused too many problems. Until our - // nftables table is more mature, we'd rather err on the side of - // wasting a little disk by including iptables for people who - // might not need it rather than handle reports of it being - // missing. - "iptables", - }, - Recommends: []string{ - "tailscale-archive-keyring (>= 1.35.181)", - // The "ip" command isn't needed since 2021-11-01 in - // 408b0923a61972ed but kept as an option as of - // 2021-11-18 in d24ed3f68e35e802d531371. See - // https://github.com/tailscale/tailscale/issues/391. - // We keep it recommended because it's usually - // installed anyway and it's useful for debugging. But - // we can live without it, so it's not Depends. - "iproute2", - }, - Replaces: []string{"tailscale-relay"}, - Conflicts: []string{"tailscale-relay"}, - }, - }) - pkg, err := nfpm.Get("deb") - if err != nil { - return nil, err - } - - filename := fmt.Sprintf("tailscale_%s_%s.deb", b.Version.Short, arch) - log.Printf("Building %s", filename) - f, err := os.Create(filepath.Join(b.Out, filename)) - if err != nil { - return nil, err - } - defer f.Close() - if err := pkg.Package(info, f); err != nil { - return nil, err - } - if err := f.Close(); err != nil { - return nil, err - } - - return []string{filename}, nil -} - -type rpmTarget struct { - goEnv map[string]string - signer dist.Signer -} - -func (t *rpmTarget) os() string { - return t.goEnv["GOOS"] -} - -func (t *rpmTarget) arch() string { - return t.goEnv["GOARCH"] -} - -func (t *rpmTarget) String() string { - return fmt.Sprintf("linux/%s/rpm", t.arch()) -} - -func (t *rpmTarget) Build(b *dist.Build) ([]string, error) { - if t.os() != "linux" { - return nil, errors.New("rpm only supported on linux") - } - - if err := b.BuildWebClientAssets(); err != nil { - return nil, err - } - ts, err := b.BuildGoBinary("tailscale.com/cmd/tailscale", t.goEnv) - if err != nil { - return nil, err - } - tsd, err := b.BuildGoBinary("tailscale.com/cmd/tailscaled", t.goEnv) - if err != nil { - return nil, err - } - - tailscaledDir, err := b.GoPkg("tailscale.com/cmd/tailscaled") - if err != nil { - return nil, err - } - repoDir, err := b.GoPkg("tailscale.com") - if err != nil { - return nil, err - } - - arch := rpmArch(t.arch()) - contents, err := files.PrepareForPackager(files.Contents{ - &files.Content{ - Type: files.TypeFile, - Source: ts, - Destination: "/usr/bin/tailscale", - }, - &files.Content{ - Type: files.TypeFile, - Source: tsd, - Destination: "/usr/sbin/tailscaled", - }, - &files.Content{ - Type: files.TypeFile, - Source: filepath.Join(tailscaledDir, "tailscaled.service"), - Destination: "/lib/systemd/system/tailscaled.service", - }, - &files.Content{ - Type: files.TypeConfigNoReplace, - Source: filepath.Join(tailscaledDir, "tailscaled.defaults"), - Destination: "/etc/default/tailscaled", - }, - // SELinux policy on e.g. CentOS 8 forbids writing to /var/cache. - // Creating an empty directory at install time resolves this issue. - &files.Content{ - Type: files.TypeDir, - Destination: "/var/cache/tailscale", - }, - }, 0, "rpm", false) - if err != nil { - return nil, err - } - info := nfpm.WithDefaults(&nfpm.Info{ - Name: "tailscale", - Arch: arch, - Platform: "linux", - Version: b.Version.Short, - Maintainer: "Tailscale Inc ", - Description: "The easiest, most secure, cross platform way to use WireGuard + oauth2 + 2FA/SSO", - Homepage: "https://www.tailscale.com", - License: "MIT", - Overridables: nfpm.Overridables{ - Contents: contents, - Scripts: nfpm.Scripts{ - PostInstall: filepath.Join(repoDir, "release/rpm/rpm.postinst.sh"), - PreRemove: filepath.Join(repoDir, "release/rpm/rpm.prerm.sh"), - PostRemove: filepath.Join(repoDir, "release/rpm/rpm.postrm.sh"), - }, - Depends: []string{"iptables", "iproute"}, - Replaces: []string{"tailscale-relay"}, - Conflicts: []string{"tailscale-relay"}, - RPM: nfpm.RPM{ - Group: "Network", - Signature: nfpm.RPMSignature{ - PackageSignature: nfpm.PackageSignature{ - SignFn: t.signer, - }, - }, - }, - }, - }) - pkg, err := nfpm.Get("rpm") - if err != nil { - return nil, err - } - - filename := fmt.Sprintf("tailscale_%s_%s.rpm", b.Version.Short, arch) - log.Printf("Building %s", filename) - - f, err := os.Create(filepath.Join(b.Out, filename)) - if err != nil { - return nil, err - } - defer f.Close() - if err := pkg.Package(info, f); err != nil { - return nil, err - } - if err := f.Close(); err != nil { - return nil, err - } - - return []string{filename}, nil -} - -// debArch returns the debian arch name for the given Go arch name. -// nfpm also does this translation internally, but we need to do it outside nfpm -// because we also need the filename to be correct. -func debArch(arch string) string { - switch arch { - case "386": - return "i386" - case "arm": - // TODO: this is supposed to be "armel" for GOARM=5, and "armhf" for - // GOARM=6 and 7. But we have some tech debt to pay off here before we - // can ship more than 1 ARM deb, so for now match redo's behavior of - // shipping armv5 binaries in an armv7 trenchcoat. - return "armhf" - case "mipsle": - return "mipsel" - case "mips64le": - return "mips64el" - default: - return arch - } -} - -// rpmArch returns the RPM arch name for the given Go arch name. -// nfpm also does this translation internally, but we need to do it outside nfpm -// because we also need the filename to be correct. -func rpmArch(arch string) string { - switch arch { - case "amd64": - return "x86_64" - case "386": - return "i386" - case "arm": - return "armv7hl" - case "arm64": - return "aarch64" - case "mipsle": - return "mipsel" - case "mips64le": - return "mips64el" - default: - return arch - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package unixpkgs contains dist Targets for building unix Tailscale packages. +package unixpkgs + +import ( + "archive/tar" + "compress/gzip" + "errors" + "fmt" + "io" + "log" + "os" + "path/filepath" + "strings" + + "github.com/goreleaser/nfpm/v2" + "github.com/goreleaser/nfpm/v2/files" + "tailscale.com/release/dist" +) + +type tgzTarget struct { + filenameArch string // arch to use in filename instead of deriving from goEnv["GOARCH"] + goEnv map[string]string + signer dist.Signer +} + +func (t *tgzTarget) arch() string { + if t.filenameArch != "" { + return t.filenameArch + } + return t.goEnv["GOARCH"] +} + +func (t *tgzTarget) os() string { + return t.goEnv["GOOS"] +} + +func (t *tgzTarget) String() string { + return fmt.Sprintf("%s/%s/tgz", t.os(), t.arch()) +} + +func (t *tgzTarget) Build(b *dist.Build) ([]string, error) { + var filename string + if t.goEnv["GOOS"] == "linux" { + // Linux used to be the only tgz architecture, so we didn't put the OS + // name in the filename. + filename = fmt.Sprintf("tailscale_%s_%s.tgz", b.Version.Short, t.arch()) + } else { + filename = fmt.Sprintf("tailscale_%s_%s_%s.tgz", b.Version.Short, t.os(), t.arch()) + } + if err := b.BuildWebClientAssets(); err != nil { + return nil, err + } + ts, err := b.BuildGoBinary("tailscale.com/cmd/tailscale", t.goEnv) + if err != nil { + return nil, err + } + tsd, err := b.BuildGoBinary("tailscale.com/cmd/tailscaled", t.goEnv) + if err != nil { + return nil, err + } + + log.Printf("Building %s", filename) + + out := filepath.Join(b.Out, filename) + f, err := os.Create(out) + if err != nil { + return nil, err + } + defer f.Close() + gw := gzip.NewWriter(f) + defer gw.Close() + tw := tar.NewWriter(gw) + defer tw.Close() + + addFile := func(src, dst string, mode int64) error { + f, err := os.Open(src) + if err != nil { + return err + } + defer f.Close() + fi, err := f.Stat() + if err != nil { + return err + } + hdr := &tar.Header{ + Name: dst, + Size: fi.Size(), + Mode: mode, + ModTime: b.Time, + Uid: 0, + Gid: 0, + Uname: "root", + Gname: "root", + } + if err := tw.WriteHeader(hdr); err != nil { + return err + } + if _, err = io.Copy(tw, f); err != nil { + return err + } + return nil + } + addDir := func(name string) error { + hdr := &tar.Header{ + Name: name + "/", + Mode: 0755, + ModTime: b.Time, + Uid: 0, + Gid: 0, + Uname: "root", + Gname: "root", + } + return tw.WriteHeader(hdr) + } + dir := strings.TrimSuffix(filename, ".tgz") + if err := addDir(dir); err != nil { + return nil, err + } + if err := addFile(tsd, filepath.Join(dir, "tailscaled"), 0755); err != nil { + return nil, err + } + if err := addFile(ts, filepath.Join(dir, "tailscale"), 0755); err != nil { + return nil, err + } + if t.os() == "linux" { + dir = filepath.Join(dir, "systemd") + if err := addDir(dir); err != nil { + return nil, err + } + tailscaledDir, err := b.GoPkg("tailscale.com/cmd/tailscaled") + if err != nil { + return nil, err + } + if err := addFile(filepath.Join(tailscaledDir, "tailscaled.service"), filepath.Join(dir, "tailscaled.service"), 0644); err != nil { + return nil, err + } + if err := addFile(filepath.Join(tailscaledDir, "tailscaled.defaults"), filepath.Join(dir, "tailscaled.defaults"), 0644); err != nil { + return nil, err + } + } + if err := tw.Close(); err != nil { + return nil, err + } + if err := gw.Close(); err != nil { + return nil, err + } + if err := f.Close(); err != nil { + return nil, err + } + + files := []string{filename} + + if t.signer != nil { + outSig := out + ".sig" + if err := t.signer.SignFile(out, outSig); err != nil { + return nil, err + } + files = append(files, filepath.Base(outSig)) + } + + return files, nil +} + +type debTarget struct { + goEnv map[string]string +} + +func (t *debTarget) os() string { + return t.goEnv["GOOS"] +} + +func (t *debTarget) arch() string { + return t.goEnv["GOARCH"] +} + +func (t *debTarget) String() string { + return fmt.Sprintf("linux/%s/deb", t.goEnv["GOARCH"]) +} + +func (t *debTarget) Build(b *dist.Build) ([]string, error) { + if t.os() != "linux" { + return nil, errors.New("deb only supported on linux") + } + + if err := b.BuildWebClientAssets(); err != nil { + return nil, err + } + ts, err := b.BuildGoBinary("tailscale.com/cmd/tailscale", t.goEnv) + if err != nil { + return nil, err + } + tsd, err := b.BuildGoBinary("tailscale.com/cmd/tailscaled", t.goEnv) + if err != nil { + return nil, err + } + + tailscaledDir, err := b.GoPkg("tailscale.com/cmd/tailscaled") + if err != nil { + return nil, err + } + repoDir, err := b.GoPkg("tailscale.com") + if err != nil { + return nil, err + } + + arch := debArch(t.arch()) + contents, err := files.PrepareForPackager(files.Contents{ + &files.Content{ + Type: files.TypeFile, + Source: ts, + Destination: "/usr/bin/tailscale", + }, + &files.Content{ + Type: files.TypeFile, + Source: tsd, + Destination: "/usr/sbin/tailscaled", + }, + &files.Content{ + Type: files.TypeFile, + Source: filepath.Join(tailscaledDir, "tailscaled.service"), + Destination: "/lib/systemd/system/tailscaled.service", + }, + &files.Content{ + Type: files.TypeConfigNoReplace, + Source: filepath.Join(tailscaledDir, "tailscaled.defaults"), + Destination: "/etc/default/tailscaled", + }, + }, 0, "deb", false) + if err != nil { + return nil, err + } + info := nfpm.WithDefaults(&nfpm.Info{ + Name: "tailscale", + Arch: arch, + Platform: "linux", + Version: b.Version.Short, + Maintainer: "Tailscale Inc ", + Description: "The easiest, most secure, cross platform way to use WireGuard + oauth2 + 2FA/SSO", + Homepage: "https://www.tailscale.com", + License: "MIT", + Section: "net", + Priority: "extra", + Overridables: nfpm.Overridables{ + Contents: contents, + Scripts: nfpm.Scripts{ + PostInstall: filepath.Join(repoDir, "release/deb/debian.postinst.sh"), + PreRemove: filepath.Join(repoDir, "release/deb/debian.prerm.sh"), + PostRemove: filepath.Join(repoDir, "release/deb/debian.postrm.sh"), + }, + Depends: []string{ + // iptables is almost always required but not strictly needed. + // Even if you can technically run Tailscale without it (by + // manually configuring nftables or userspace mode), we still + // mark this as "Depends" because our previous experiment in + // https://github.com/tailscale/tailscale/issues/9236 of making + // it only Recommends caused too many problems. Until our + // nftables table is more mature, we'd rather err on the side of + // wasting a little disk by including iptables for people who + // might not need it rather than handle reports of it being + // missing. + "iptables", + }, + Recommends: []string{ + "tailscale-archive-keyring (>= 1.35.181)", + // The "ip" command isn't needed since 2021-11-01 in + // 408b0923a61972ed but kept as an option as of + // 2021-11-18 in d24ed3f68e35e802d531371. See + // https://github.com/tailscale/tailscale/issues/391. + // We keep it recommended because it's usually + // installed anyway and it's useful for debugging. But + // we can live without it, so it's not Depends. + "iproute2", + }, + Replaces: []string{"tailscale-relay"}, + Conflicts: []string{"tailscale-relay"}, + }, + }) + pkg, err := nfpm.Get("deb") + if err != nil { + return nil, err + } + + filename := fmt.Sprintf("tailscale_%s_%s.deb", b.Version.Short, arch) + log.Printf("Building %s", filename) + f, err := os.Create(filepath.Join(b.Out, filename)) + if err != nil { + return nil, err + } + defer f.Close() + if err := pkg.Package(info, f); err != nil { + return nil, err + } + if err := f.Close(); err != nil { + return nil, err + } + + return []string{filename}, nil +} + +type rpmTarget struct { + goEnv map[string]string + signer dist.Signer +} + +func (t *rpmTarget) os() string { + return t.goEnv["GOOS"] +} + +func (t *rpmTarget) arch() string { + return t.goEnv["GOARCH"] +} + +func (t *rpmTarget) String() string { + return fmt.Sprintf("linux/%s/rpm", t.arch()) +} + +func (t *rpmTarget) Build(b *dist.Build) ([]string, error) { + if t.os() != "linux" { + return nil, errors.New("rpm only supported on linux") + } + + if err := b.BuildWebClientAssets(); err != nil { + return nil, err + } + ts, err := b.BuildGoBinary("tailscale.com/cmd/tailscale", t.goEnv) + if err != nil { + return nil, err + } + tsd, err := b.BuildGoBinary("tailscale.com/cmd/tailscaled", t.goEnv) + if err != nil { + return nil, err + } + + tailscaledDir, err := b.GoPkg("tailscale.com/cmd/tailscaled") + if err != nil { + return nil, err + } + repoDir, err := b.GoPkg("tailscale.com") + if err != nil { + return nil, err + } + + arch := rpmArch(t.arch()) + contents, err := files.PrepareForPackager(files.Contents{ + &files.Content{ + Type: files.TypeFile, + Source: ts, + Destination: "/usr/bin/tailscale", + }, + &files.Content{ + Type: files.TypeFile, + Source: tsd, + Destination: "/usr/sbin/tailscaled", + }, + &files.Content{ + Type: files.TypeFile, + Source: filepath.Join(tailscaledDir, "tailscaled.service"), + Destination: "/lib/systemd/system/tailscaled.service", + }, + &files.Content{ + Type: files.TypeConfigNoReplace, + Source: filepath.Join(tailscaledDir, "tailscaled.defaults"), + Destination: "/etc/default/tailscaled", + }, + // SELinux policy on e.g. CentOS 8 forbids writing to /var/cache. + // Creating an empty directory at install time resolves this issue. + &files.Content{ + Type: files.TypeDir, + Destination: "/var/cache/tailscale", + }, + }, 0, "rpm", false) + if err != nil { + return nil, err + } + info := nfpm.WithDefaults(&nfpm.Info{ + Name: "tailscale", + Arch: arch, + Platform: "linux", + Version: b.Version.Short, + Maintainer: "Tailscale Inc ", + Description: "The easiest, most secure, cross platform way to use WireGuard + oauth2 + 2FA/SSO", + Homepage: "https://www.tailscale.com", + License: "MIT", + Overridables: nfpm.Overridables{ + Contents: contents, + Scripts: nfpm.Scripts{ + PostInstall: filepath.Join(repoDir, "release/rpm/rpm.postinst.sh"), + PreRemove: filepath.Join(repoDir, "release/rpm/rpm.prerm.sh"), + PostRemove: filepath.Join(repoDir, "release/rpm/rpm.postrm.sh"), + }, + Depends: []string{"iptables", "iproute"}, + Replaces: []string{"tailscale-relay"}, + Conflicts: []string{"tailscale-relay"}, + RPM: nfpm.RPM{ + Group: "Network", + Signature: nfpm.RPMSignature{ + PackageSignature: nfpm.PackageSignature{ + SignFn: t.signer, + }, + }, + }, + }, + }) + pkg, err := nfpm.Get("rpm") + if err != nil { + return nil, err + } + + filename := fmt.Sprintf("tailscale_%s_%s.rpm", b.Version.Short, arch) + log.Printf("Building %s", filename) + + f, err := os.Create(filepath.Join(b.Out, filename)) + if err != nil { + return nil, err + } + defer f.Close() + if err := pkg.Package(info, f); err != nil { + return nil, err + } + if err := f.Close(); err != nil { + return nil, err + } + + return []string{filename}, nil +} + +// debArch returns the debian arch name for the given Go arch name. +// nfpm also does this translation internally, but we need to do it outside nfpm +// because we also need the filename to be correct. +func debArch(arch string) string { + switch arch { + case "386": + return "i386" + case "arm": + // TODO: this is supposed to be "armel" for GOARM=5, and "armhf" for + // GOARM=6 and 7. But we have some tech debt to pay off here before we + // can ship more than 1 ARM deb, so for now match redo's behavior of + // shipping armv5 binaries in an armv7 trenchcoat. + return "armhf" + case "mipsle": + return "mipsel" + case "mips64le": + return "mips64el" + default: + return arch + } +} + +// rpmArch returns the RPM arch name for the given Go arch name. +// nfpm also does this translation internally, but we need to do it outside nfpm +// because we also need the filename to be correct. +func rpmArch(arch string) string { + switch arch { + case "amd64": + return "x86_64" + case "386": + return "i386" + case "arm": + return "armv7hl" + case "arm64": + return "aarch64" + case "mipsle": + return "mipsel" + case "mips64le": + return "mips64el" + default: + return arch + } +} diff --git a/release/dist/unixpkgs/targets.go b/release/dist/unixpkgs/targets.go index f87c56d31..42bab6d3b 100644 --- a/release/dist/unixpkgs/targets.go +++ b/release/dist/unixpkgs/targets.go @@ -1,127 +1,127 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package unixpkgs - -import ( - "fmt" - "sort" - "strings" - - "tailscale.com/release/dist" - - _ "github.com/goreleaser/nfpm/v2/deb" - _ "github.com/goreleaser/nfpm/v2/rpm" -) - -type Signers struct { - Tarball dist.Signer - RPM dist.Signer -} - -func Targets(signers Signers) []dist.Target { - var ret []dist.Target - for goosgoarch := range tarballs { - goos, goarch := splitGoosGoarch(goosgoarch) - ret = append(ret, &tgzTarget{ - goEnv: map[string]string{ - "GOOS": goos, - "GOARCH": goarch, - }, - signer: signers.Tarball, - }) - } - for goosgoarch := range debs { - goos, goarch := splitGoosGoarch(goosgoarch) - ret = append(ret, &debTarget{ - goEnv: map[string]string{ - "GOOS": goos, - "GOARCH": goarch, - }, - }) - } - for goosgoarch := range rpms { - goos, goarch := splitGoosGoarch(goosgoarch) - ret = append(ret, &rpmTarget{ - goEnv: map[string]string{ - "GOOS": goos, - "GOARCH": goarch, - }, - signer: signers.RPM, - }) - } - - // Special case: AMD Geode is 386 with softfloat. Tarballs only since it's - // an ancient architecture. - ret = append(ret, &tgzTarget{ - filenameArch: "geode", - goEnv: map[string]string{ - "GOOS": "linux", - "GOARCH": "386", - "GO386": "softfloat", - }, - signer: signers.Tarball, - }) - - sort.Slice(ret, func(i, j int) bool { - return ret[i].String() < ret[j].String() - }) - - return ret -} - -var ( - tarballs = map[string]bool{ - "linux/386": true, - "linux/amd64": true, - "linux/arm": true, - "linux/arm64": true, - "linux/mips64": true, - "linux/mips64le": true, - "linux/mips": true, - "linux/mipsle": true, - "linux/riscv64": true, - // TODO: more tarballs we could distribute, but don't currently. Leaving - // out for initial parity with redo. - // "darwin/amd64": true, - // "darwin/arm64": true, - // "freebsd/amd64": true, - // "openbsd/amd64": true, - } - - debs = map[string]bool{ - "linux/386": true, - "linux/amd64": true, - "linux/arm": true, - "linux/arm64": true, - "linux/riscv64": true, - "linux/mipsle": true, - "linux/mips64le": true, - "linux/mips": true, - // Debian does not support big endian mips64. Leave that out until we know - // we need it. - // "linux/mips64": true, - } - - rpms = map[string]bool{ - "linux/386": true, - "linux/amd64": true, - "linux/arm": true, - "linux/arm64": true, - "linux/riscv64": true, - "linux/mipsle": true, - "linux/mips64le": true, - // Fedora only supports little endian mipses. Maybe some other distribution - // supports big-endian? Leave them out for now. - // "linux/mips": true, - // "linux/mips64": true, - } -) - -func splitGoosGoarch(s string) (string, string) { - goos, goarch, ok := strings.Cut(s, "/") - if !ok { - panic(fmt.Sprintf("invalid target %q", s)) - } - return goos, goarch -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package unixpkgs + +import ( + "fmt" + "sort" + "strings" + + "tailscale.com/release/dist" + + _ "github.com/goreleaser/nfpm/v2/deb" + _ "github.com/goreleaser/nfpm/v2/rpm" +) + +type Signers struct { + Tarball dist.Signer + RPM dist.Signer +} + +func Targets(signers Signers) []dist.Target { + var ret []dist.Target + for goosgoarch := range tarballs { + goos, goarch := splitGoosGoarch(goosgoarch) + ret = append(ret, &tgzTarget{ + goEnv: map[string]string{ + "GOOS": goos, + "GOARCH": goarch, + }, + signer: signers.Tarball, + }) + } + for goosgoarch := range debs { + goos, goarch := splitGoosGoarch(goosgoarch) + ret = append(ret, &debTarget{ + goEnv: map[string]string{ + "GOOS": goos, + "GOARCH": goarch, + }, + }) + } + for goosgoarch := range rpms { + goos, goarch := splitGoosGoarch(goosgoarch) + ret = append(ret, &rpmTarget{ + goEnv: map[string]string{ + "GOOS": goos, + "GOARCH": goarch, + }, + signer: signers.RPM, + }) + } + + // Special case: AMD Geode is 386 with softfloat. Tarballs only since it's + // an ancient architecture. + ret = append(ret, &tgzTarget{ + filenameArch: "geode", + goEnv: map[string]string{ + "GOOS": "linux", + "GOARCH": "386", + "GO386": "softfloat", + }, + signer: signers.Tarball, + }) + + sort.Slice(ret, func(i, j int) bool { + return ret[i].String() < ret[j].String() + }) + + return ret +} + +var ( + tarballs = map[string]bool{ + "linux/386": true, + "linux/amd64": true, + "linux/arm": true, + "linux/arm64": true, + "linux/mips64": true, + "linux/mips64le": true, + "linux/mips": true, + "linux/mipsle": true, + "linux/riscv64": true, + // TODO: more tarballs we could distribute, but don't currently. Leaving + // out for initial parity with redo. + // "darwin/amd64": true, + // "darwin/arm64": true, + // "freebsd/amd64": true, + // "openbsd/amd64": true, + } + + debs = map[string]bool{ + "linux/386": true, + "linux/amd64": true, + "linux/arm": true, + "linux/arm64": true, + "linux/riscv64": true, + "linux/mipsle": true, + "linux/mips64le": true, + "linux/mips": true, + // Debian does not support big endian mips64. Leave that out until we know + // we need it. + // "linux/mips64": true, + } + + rpms = map[string]bool{ + "linux/386": true, + "linux/amd64": true, + "linux/arm": true, + "linux/arm64": true, + "linux/riscv64": true, + "linux/mipsle": true, + "linux/mips64le": true, + // Fedora only supports little endian mipses. Maybe some other distribution + // supports big-endian? Leave them out for now. + // "linux/mips": true, + // "linux/mips64": true, + } +) + +func splitGoosGoarch(s string) (string, string) { + goos, goarch, ok := strings.Cut(s, "/") + if !ok { + panic(fmt.Sprintf("invalid target %q", s)) + } + return goos, goarch +} diff --git a/release/release.go b/release/release.go index 638635b6d..a8d0e6b62 100644 --- a/release/release.go +++ b/release/release.go @@ -1,15 +1,15 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package release provides functionality for building client releases. -package release - -import "embed" - -// This contains all files in the release directory, -// notably the files needed for deb, rpm, and similar packages. -// Because we assign this to the blank identifier, it does not actually embed the files. -// However, this does cause `go mod vendor` to include the files when vendoring the package. -// -//go:embed * -var _ embed.FS +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package release provides functionality for building client releases. +package release + +import "embed" + +// This contains all files in the release directory, +// notably the files needed for deb, rpm, and similar packages. +// Because we assign this to the blank identifier, it does not actually embed the files. +// However, this does cause `go mod vendor` to include the files when vendoring the package. +// +//go:embed * +var _ embed.FS diff --git a/release/rpm/rpm.postinst.sh b/release/rpm/rpm.postinst.sh index f9c1fddfd..3d264c5f6 100755 --- a/release/rpm/rpm.postinst.sh +++ b/release/rpm/rpm.postinst.sh @@ -1,41 +1,41 @@ -# $1 == 1 for initial installation. -# $1 == 2 for upgrades. - -if [ $1 -eq 1 ] ; then - # Normally, the tailscale-relay package would request shutdown of - # its service before uninstallation. Unfortunately, the - # tailscale-relay package we distributed doesn't have those - # scriptlets. We definitely want relaynode to be stopped when - # installing tailscaled though, so we blindly try to turn off - # relaynode here. - # - # However, we also want this package installation to look like an - # upgrade from relaynode! Therefore, if relaynode is currently - # enabled, we want to also enable tailscaled. If relaynode is - # currently running, we also want to start tailscaled. - # - # If there doesn't seem to be an active or enabled relaynode on - # the system, we follow the RPM convention for package installs, - # which is to not enable or start the service. - relaynode_enabled=0 - relaynode_running=0 - if systemctl is-enabled tailscale-relay.service >/dev/null 2>&1; then - relaynode_enabled=1 - fi - if systemctl is-active tailscale-relay.service >/dev/null 2>&1; then - relaynode_running=1 - fi - - systemctl --no-reload disable tailscale-relay.service >/dev/null 2>&1 || : - systemctl stop tailscale-relay.service >/dev/null 2>&1 || : - - if [ $relaynode_enabled -eq 1 ]; then - systemctl enable tailscaled.service >/dev/null 2>&1 || : - else - systemctl preset tailscaled.service >/dev/null 2>&1 || : - fi - - if [ $relaynode_running -eq 1 ]; then - systemctl start tailscaled.service >/dev/null 2>&1 || : - fi -fi +# $1 == 1 for initial installation. +# $1 == 2 for upgrades. + +if [ $1 -eq 1 ] ; then + # Normally, the tailscale-relay package would request shutdown of + # its service before uninstallation. Unfortunately, the + # tailscale-relay package we distributed doesn't have those + # scriptlets. We definitely want relaynode to be stopped when + # installing tailscaled though, so we blindly try to turn off + # relaynode here. + # + # However, we also want this package installation to look like an + # upgrade from relaynode! Therefore, if relaynode is currently + # enabled, we want to also enable tailscaled. If relaynode is + # currently running, we also want to start tailscaled. + # + # If there doesn't seem to be an active or enabled relaynode on + # the system, we follow the RPM convention for package installs, + # which is to not enable or start the service. + relaynode_enabled=0 + relaynode_running=0 + if systemctl is-enabled tailscale-relay.service >/dev/null 2>&1; then + relaynode_enabled=1 + fi + if systemctl is-active tailscale-relay.service >/dev/null 2>&1; then + relaynode_running=1 + fi + + systemctl --no-reload disable tailscale-relay.service >/dev/null 2>&1 || : + systemctl stop tailscale-relay.service >/dev/null 2>&1 || : + + if [ $relaynode_enabled -eq 1 ]; then + systemctl enable tailscaled.service >/dev/null 2>&1 || : + else + systemctl preset tailscaled.service >/dev/null 2>&1 || : + fi + + if [ $relaynode_running -eq 1 ]; then + systemctl start tailscaled.service >/dev/null 2>&1 || : + fi +fi diff --git a/release/rpm/rpm.postrm.sh b/release/rpm/rpm.postrm.sh index e19a7305c..d74f7e9de 100755 --- a/release/rpm/rpm.postrm.sh +++ b/release/rpm/rpm.postrm.sh @@ -1,8 +1,8 @@ -# $1 == 0 for uninstallation. -# $1 == 1 for removing old package during upgrade. - -systemctl daemon-reload >/dev/null 2>&1 || : -if [ $1 -ge 1 ] ; then - # Package upgrade, not uninstall - systemctl try-restart tailscaled.service >/dev/null 2>&1 || : -fi +# $1 == 0 for uninstallation. +# $1 == 1 for removing old package during upgrade. + +systemctl daemon-reload >/dev/null 2>&1 || : +if [ $1 -ge 1 ] ; then + # Package upgrade, not uninstall + systemctl try-restart tailscaled.service >/dev/null 2>&1 || : +fi diff --git a/release/rpm/rpm.prerm.sh b/release/rpm/rpm.prerm.sh index eeabf3b58..682c01bd5 100755 --- a/release/rpm/rpm.prerm.sh +++ b/release/rpm/rpm.prerm.sh @@ -1,8 +1,8 @@ -# $1 == 0 for uninstallation. -# $1 == 1 for removing old package during upgrade. - -if [ $1 -eq 0 ] ; then - # Package removal, not upgrade - systemctl --no-reload disable tailscaled.service > /dev/null 2>&1 || : - systemctl stop tailscaled.service > /dev/null 2>&1 || : -fi +# $1 == 0 for uninstallation. +# $1 == 1 for removing old package during upgrade. + +if [ $1 -eq 0 ] ; then + # Package removal, not upgrade + systemctl --no-reload disable tailscaled.service > /dev/null 2>&1 || : + systemctl stop tailscaled.service > /dev/null 2>&1 || : +fi diff --git a/safesocket/safesocket_test.go b/safesocket/safesocket_test.go index 85b317bd6..3f36a1cf6 100644 --- a/safesocket/safesocket_test.go +++ b/safesocket/safesocket_test.go @@ -1,12 +1,12 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package safesocket - -import "testing" - -func TestLocalTCPPortAndToken(t *testing.T) { - // Just test that it compiles for now (is available on all platforms). - port, token, err := LocalTCPPortAndToken() - t.Logf("got %v, %s, %v", port, token, err) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package safesocket + +import "testing" + +func TestLocalTCPPortAndToken(t *testing.T) { + // Just test that it compiles for now (is available on all platforms). + port, token, err := LocalTCPPortAndToken() + t.Logf("got %v, %s, %v", port, token, err) +} diff --git a/smallzstd/testdata b/smallzstd/testdata index 498b014fd..76640fdc5 100644 --- a/smallzstd/testdata +++ b/smallzstd/testdata @@ -1,14 +1,14 @@ -{"logtail":{"client_time":"2020-07-01T14:49:40.196597018-07:00","server_time":"2020-07-01T21:49:40.198371511Z"},"text":"9.8M/25.6M magicsock: starting endpoint update (periodic)\n"} -{"logtail":{"client_time":"2020-07-01T14:49:40.345925455-07:00","server_time":"2020-07-01T21:49:40.347904717Z"},"text":"9.9M/25.6M netcheck: udp=true v6=false mapvarydest=false hair=false v4a=202.188.7.1:41641 derp=2 derpdist=1v4:7ms,2v4:3ms,4v4:18ms\n"} -{"logtail":{"client_time":"2020-07-01T14:49:43.347155742-07:00","server_time":"2020-07-01T21:49:43.34828658Z"},"text":"9.9M/25.6M control: map response long-poll timed out!\n"} -{"logtail":{"client_time":"2020-07-01T14:49:43.347539333-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"9.9M/25.6M control: PollNetMap: context canceled\n"} -{"logtail":{"client_time":"2020-07-01T14:49:43.347767812-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"10.0M/25.6M control: sendStatus: mapRoutine1: state:authenticated\n"} -{"logtail":{"client_time":"2020-07-01T14:49:43.347817165-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"10.0M/25.6M blockEngineUpdates(false)\n"} -{"logtail":{"client_time":"2020-07-01T14:49:43.347989028-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"10.0M/25.6M wgcfg: [SViTM] skipping subnet route\n"} -{"logtail":{"client_time":"2020-07-01T14:49:43.349997554-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"9.3M/25.6M Received error: PollNetMap: context canceled\n"} -{"logtail":{"client_time":"2020-07-01T14:49:43.350072606-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"9.3M/25.6M control: mapRoutine: backoff: 30136 msec\n"} -{"logtail":{"client_time":"2020-07-01T14:49:47.998364646-07:00","server_time":"2020-07-01T21:49:47.999333754Z"},"text":"9.5M/25.6M [W1NbE] - [UcppE] Send handshake init [127.3.3.40:1, 6.1.1.6:37388*, 10.3.2.6:41641]\n"} -{"logtail":{"client_time":"2020-07-01T14:49:47.99881914-07:00","server_time":"2020-07-01T21:49:48.009859543Z"},"text":"9.6M/25.6M magicsock: adding connection to derp-1 for [W1NbE]\n"} -{"logtail":{"client_time":"2020-07-01T14:49:47.998904932-07:00","server_time":"2020-07-01T21:49:48.009859543Z"},"text":"9.6M/25.6M magicsock: 2 active derp conns: derp-1=cr0s,wr0s derp-2=cr16h0m0s,wr14h38m0s\n"} -{"logtail":{"client_time":"2020-07-01T14:49:47.999045606-07:00","server_time":"2020-07-01T21:49:48.009859543Z"},"text":"9.6M/25.6M derphttp.Client.Recv: connecting to derp-1 (nyc)\n"} -{"logtail":{"client_time":"2020-07-01T14:49:48.091104119-07:00","server_time":"2020-07-01T21:49:48.09280535Z"},"text":"9.6M/25.6M magicsock: rx [W1NbE] from 6.1.1.6:37388 (1/3), set as new priority\n"} +{"logtail":{"client_time":"2020-07-01T14:49:40.196597018-07:00","server_time":"2020-07-01T21:49:40.198371511Z"},"text":"9.8M/25.6M magicsock: starting endpoint update (periodic)\n"} +{"logtail":{"client_time":"2020-07-01T14:49:40.345925455-07:00","server_time":"2020-07-01T21:49:40.347904717Z"},"text":"9.9M/25.6M netcheck: udp=true v6=false mapvarydest=false hair=false v4a=202.188.7.1:41641 derp=2 derpdist=1v4:7ms,2v4:3ms,4v4:18ms\n"} +{"logtail":{"client_time":"2020-07-01T14:49:43.347155742-07:00","server_time":"2020-07-01T21:49:43.34828658Z"},"text":"9.9M/25.6M control: map response long-poll timed out!\n"} +{"logtail":{"client_time":"2020-07-01T14:49:43.347539333-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"9.9M/25.6M control: PollNetMap: context canceled\n"} +{"logtail":{"client_time":"2020-07-01T14:49:43.347767812-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"10.0M/25.6M control: sendStatus: mapRoutine1: state:authenticated\n"} +{"logtail":{"client_time":"2020-07-01T14:49:43.347817165-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"10.0M/25.6M blockEngineUpdates(false)\n"} +{"logtail":{"client_time":"2020-07-01T14:49:43.347989028-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"10.0M/25.6M wgcfg: [SViTM] skipping subnet route\n"} +{"logtail":{"client_time":"2020-07-01T14:49:43.349997554-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"9.3M/25.6M Received error: PollNetMap: context canceled\n"} +{"logtail":{"client_time":"2020-07-01T14:49:43.350072606-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"9.3M/25.6M control: mapRoutine: backoff: 30136 msec\n"} +{"logtail":{"client_time":"2020-07-01T14:49:47.998364646-07:00","server_time":"2020-07-01T21:49:47.999333754Z"},"text":"9.5M/25.6M [W1NbE] - [UcppE] Send handshake init [127.3.3.40:1, 6.1.1.6:37388*, 10.3.2.6:41641]\n"} +{"logtail":{"client_time":"2020-07-01T14:49:47.99881914-07:00","server_time":"2020-07-01T21:49:48.009859543Z"},"text":"9.6M/25.6M magicsock: adding connection to derp-1 for [W1NbE]\n"} +{"logtail":{"client_time":"2020-07-01T14:49:47.998904932-07:00","server_time":"2020-07-01T21:49:48.009859543Z"},"text":"9.6M/25.6M magicsock: 2 active derp conns: derp-1=cr0s,wr0s derp-2=cr16h0m0s,wr14h38m0s\n"} +{"logtail":{"client_time":"2020-07-01T14:49:47.999045606-07:00","server_time":"2020-07-01T21:49:48.009859543Z"},"text":"9.6M/25.6M derphttp.Client.Recv: connecting to derp-1 (nyc)\n"} +{"logtail":{"client_time":"2020-07-01T14:49:48.091104119-07:00","server_time":"2020-07-01T21:49:48.09280535Z"},"text":"9.6M/25.6M magicsock: rx [W1NbE] from 6.1.1.6:37388 (1/3), set as new priority\n"} diff --git a/smallzstd/zstd.go b/smallzstd/zstd.go index d91afeb67..1d8085422 100644 --- a/smallzstd/zstd.go +++ b/smallzstd/zstd.go @@ -1,78 +1,78 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package smallzstd produces zstd encoders and decoders optimized for -// low memory usage, at the expense of compression efficiency. -// -// This package is optimized primarily for the memory cost of -// compressing and decompressing data. We reduce this cost in two -// major ways: disable parallelism within the library (i.e. don't use -// multiple CPU cores to decompress), and drop the compression window -// down from the defaults of 4-16MiB, to 8kiB. -// -// Decompressors cost 2x the window size in RAM to run, so by using an -// 8kiB window, we can run ~1000 more decompressors per unit of memory -// than with the defaults. -// -// Depending on context, the benefit is either being able to run more -// decoders (e.g. in our logs processing system), or having a lower -// memory footprint when using compression in network protocols -// (e.g. in tailscaled, which should have a minimal RAM cost). -package smallzstd - -import ( - "io" - - "github.com/klauspost/compress/zstd" -) - -// WindowSize is the window size used for zstd compression. Decoder -// memory usage scales linearly with WindowSize. -const WindowSize = 8 << 10 // 8kiB - -// NewDecoder returns a zstd.Decoder configured for low memory usage, -// at the expense of decompression performance. -func NewDecoder(r io.Reader, options ...zstd.DOption) (*zstd.Decoder, error) { - defaults := []zstd.DOption{ - // Default is GOMAXPROCS, which costs many KiB in stacks. - zstd.WithDecoderConcurrency(1), - // Default is to allocate more upfront for performance. We - // prefer lower memory use and a bit of GC load. - zstd.WithDecoderLowmem(true), - // You might expect to see zstd.WithDecoderMaxMemory - // here. However, it's not terribly safe to use if you're - // doing stateless decoding, because it sets the maximum - // amount of memory the decompressed data can occupy, rather - // than the window size of the zstd stream. This means a very - // compressible piece of data might violate the max memory - // limit here, even if the window size (and thus total memory - // required to decompress the data) is small. - // - // As a result, we don't set a decoder limit here, and rely on - // the encoder below producing "cheap" streams. Callers are - // welcome to set their own max memory setting, if - // contextually there is a clearly correct value (e.g. it's - // known from the upper layer protocol that the decoded data - // can never be more than 1MiB). - } - - return zstd.NewReader(r, append(defaults, options...)...) -} - -// NewEncoder returns a zstd.Encoder configured for low memory usage, -// both during compression and at decompression time, at the expense -// of performance and compression efficiency. -func NewEncoder(w io.Writer, options ...zstd.EOption) (*zstd.Encoder, error) { - defaults := []zstd.EOption{ - // Default is GOMAXPROCS, which costs many KiB in stacks. - zstd.WithEncoderConcurrency(1), - // Default is several MiB, which bloats both encoders and - // their corresponding decoders. - zstd.WithWindowSize(WindowSize), - // Encode zero-length inputs in a way that the `zstd` utility - // can read, because interoperability is handy. - zstd.WithZeroFrames(true), - } - - return zstd.NewWriter(w, append(defaults, options...)...) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package smallzstd produces zstd encoders and decoders optimized for +// low memory usage, at the expense of compression efficiency. +// +// This package is optimized primarily for the memory cost of +// compressing and decompressing data. We reduce this cost in two +// major ways: disable parallelism within the library (i.e. don't use +// multiple CPU cores to decompress), and drop the compression window +// down from the defaults of 4-16MiB, to 8kiB. +// +// Decompressors cost 2x the window size in RAM to run, so by using an +// 8kiB window, we can run ~1000 more decompressors per unit of memory +// than with the defaults. +// +// Depending on context, the benefit is either being able to run more +// decoders (e.g. in our logs processing system), or having a lower +// memory footprint when using compression in network protocols +// (e.g. in tailscaled, which should have a minimal RAM cost). +package smallzstd + +import ( + "io" + + "github.com/klauspost/compress/zstd" +) + +// WindowSize is the window size used for zstd compression. Decoder +// memory usage scales linearly with WindowSize. +const WindowSize = 8 << 10 // 8kiB + +// NewDecoder returns a zstd.Decoder configured for low memory usage, +// at the expense of decompression performance. +func NewDecoder(r io.Reader, options ...zstd.DOption) (*zstd.Decoder, error) { + defaults := []zstd.DOption{ + // Default is GOMAXPROCS, which costs many KiB in stacks. + zstd.WithDecoderConcurrency(1), + // Default is to allocate more upfront for performance. We + // prefer lower memory use and a bit of GC load. + zstd.WithDecoderLowmem(true), + // You might expect to see zstd.WithDecoderMaxMemory + // here. However, it's not terribly safe to use if you're + // doing stateless decoding, because it sets the maximum + // amount of memory the decompressed data can occupy, rather + // than the window size of the zstd stream. This means a very + // compressible piece of data might violate the max memory + // limit here, even if the window size (and thus total memory + // required to decompress the data) is small. + // + // As a result, we don't set a decoder limit here, and rely on + // the encoder below producing "cheap" streams. Callers are + // welcome to set their own max memory setting, if + // contextually there is a clearly correct value (e.g. it's + // known from the upper layer protocol that the decoded data + // can never be more than 1MiB). + } + + return zstd.NewReader(r, append(defaults, options...)...) +} + +// NewEncoder returns a zstd.Encoder configured for low memory usage, +// both during compression and at decompression time, at the expense +// of performance and compression efficiency. +func NewEncoder(w io.Writer, options ...zstd.EOption) (*zstd.Encoder, error) { + defaults := []zstd.EOption{ + // Default is GOMAXPROCS, which costs many KiB in stacks. + zstd.WithEncoderConcurrency(1), + // Default is several MiB, which bloats both encoders and + // their corresponding decoders. + zstd.WithWindowSize(WindowSize), + // Encode zero-length inputs in a way that the `zstd` utility + // can read, because interoperability is handy. + zstd.WithZeroFrames(true), + } + + return zstd.NewWriter(w, append(defaults, options...)...) +} diff --git a/syncs/locked.go b/syncs/locked.go index abde5bca6..d2048665d 100644 --- a/syncs/locked.go +++ b/syncs/locked.go @@ -1,32 +1,32 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package syncs - -import ( - "sync" -) - -// AssertLocked panics if m is not locked. -func AssertLocked(m *sync.Mutex) { - if m.TryLock() { - m.Unlock() - panic("mutex is not locked") - } -} - -// AssertRLocked panics if rw is not locked for reading or writing. -func AssertRLocked(rw *sync.RWMutex) { - if rw.TryLock() { - rw.Unlock() - panic("mutex is not locked") - } -} - -// AssertWLocked panics if rw is not locked for writing. -func AssertWLocked(rw *sync.RWMutex) { - if rw.TryRLock() { - rw.RUnlock() - panic("mutex is not rlocked") - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package syncs + +import ( + "sync" +) + +// AssertLocked panics if m is not locked. +func AssertLocked(m *sync.Mutex) { + if m.TryLock() { + m.Unlock() + panic("mutex is not locked") + } +} + +// AssertRLocked panics if rw is not locked for reading or writing. +func AssertRLocked(rw *sync.RWMutex) { + if rw.TryLock() { + rw.Unlock() + panic("mutex is not locked") + } +} + +// AssertWLocked panics if rw is not locked for writing. +func AssertWLocked(rw *sync.RWMutex) { + if rw.TryRLock() { + rw.RUnlock() + panic("mutex is not rlocked") + } +} diff --git a/syncs/locked_test.go b/syncs/locked_test.go index 44877be50..90b36e832 100644 --- a/syncs/locked_test.go +++ b/syncs/locked_test.go @@ -1,120 +1,120 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build go1.13 && !go1.19 - -package syncs - -import ( - "sync" - "testing" - "time" -) - -func wantPanic(t *testing.T, fn func()) { - t.Helper() - defer func() { - recover() - }() - fn() - t.Fatal("failed to panic") -} - -func TestAssertLocked(t *testing.T) { - m := new(sync.Mutex) - wantPanic(t, func() { AssertLocked(m) }) - m.Lock() - AssertLocked(m) - m.Unlock() - wantPanic(t, func() { AssertLocked(m) }) - // Test correct handling of mutex with waiter. - m.Lock() - AssertLocked(m) - go func() { - m.Lock() - m.Unlock() - }() - // Give the goroutine above a few moments to get started. - // The test will pass whether or not we win the race, - // but we want to run sometimes, to get the test coverage. - time.Sleep(10 * time.Millisecond) - AssertLocked(m) -} - -func TestAssertWLocked(t *testing.T) { - m := new(sync.RWMutex) - wantPanic(t, func() { AssertWLocked(m) }) - m.Lock() - AssertWLocked(m) - m.Unlock() - wantPanic(t, func() { AssertWLocked(m) }) - // Test correct handling of mutex with waiter. - m.Lock() - AssertWLocked(m) - go func() { - m.Lock() - m.Unlock() - }() - // Give the goroutine above a few moments to get started. - // The test will pass whether or not we win the race, - // but we want to run sometimes, to get the test coverage. - time.Sleep(10 * time.Millisecond) - AssertWLocked(m) -} - -func TestAssertRLocked(t *testing.T) { - m := new(sync.RWMutex) - wantPanic(t, func() { AssertRLocked(m) }) - - m.Lock() - AssertRLocked(m) - m.Unlock() - - m.RLock() - AssertRLocked(m) - m.RUnlock() - - wantPanic(t, func() { AssertRLocked(m) }) - - // Test correct handling of mutex with waiter. - m.RLock() - AssertRLocked(m) - go func() { - m.RLock() - m.RUnlock() - }() - // Give the goroutine above a few moments to get started. - // The test will pass whether or not we win the race, - // but we want to run sometimes, to get the test coverage. - time.Sleep(10 * time.Millisecond) - AssertRLocked(m) - m.RUnlock() - - // Test correct handling of rlock with write waiter. - m.RLock() - AssertRLocked(m) - go func() { - m.Lock() - m.Unlock() - }() - // Give the goroutine above a few moments to get started. - // The test will pass whether or not we win the race, - // but we want to run sometimes, to get the test coverage. - time.Sleep(10 * time.Millisecond) - AssertRLocked(m) - m.RUnlock() - - // Test correct handling of rlock with other rlocks. - // This is a bit racy, but losing the race hurts nothing, - // and winning the race means correct test coverage. - m.RLock() - AssertRLocked(m) - go func() { - m.RLock() - time.Sleep(10 * time.Millisecond) - m.RUnlock() - }() - time.Sleep(5 * time.Millisecond) - AssertRLocked(m) - m.RUnlock() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build go1.13 && !go1.19 + +package syncs + +import ( + "sync" + "testing" + "time" +) + +func wantPanic(t *testing.T, fn func()) { + t.Helper() + defer func() { + recover() + }() + fn() + t.Fatal("failed to panic") +} + +func TestAssertLocked(t *testing.T) { + m := new(sync.Mutex) + wantPanic(t, func() { AssertLocked(m) }) + m.Lock() + AssertLocked(m) + m.Unlock() + wantPanic(t, func() { AssertLocked(m) }) + // Test correct handling of mutex with waiter. + m.Lock() + AssertLocked(m) + go func() { + m.Lock() + m.Unlock() + }() + // Give the goroutine above a few moments to get started. + // The test will pass whether or not we win the race, + // but we want to run sometimes, to get the test coverage. + time.Sleep(10 * time.Millisecond) + AssertLocked(m) +} + +func TestAssertWLocked(t *testing.T) { + m := new(sync.RWMutex) + wantPanic(t, func() { AssertWLocked(m) }) + m.Lock() + AssertWLocked(m) + m.Unlock() + wantPanic(t, func() { AssertWLocked(m) }) + // Test correct handling of mutex with waiter. + m.Lock() + AssertWLocked(m) + go func() { + m.Lock() + m.Unlock() + }() + // Give the goroutine above a few moments to get started. + // The test will pass whether or not we win the race, + // but we want to run sometimes, to get the test coverage. + time.Sleep(10 * time.Millisecond) + AssertWLocked(m) +} + +func TestAssertRLocked(t *testing.T) { + m := new(sync.RWMutex) + wantPanic(t, func() { AssertRLocked(m) }) + + m.Lock() + AssertRLocked(m) + m.Unlock() + + m.RLock() + AssertRLocked(m) + m.RUnlock() + + wantPanic(t, func() { AssertRLocked(m) }) + + // Test correct handling of mutex with waiter. + m.RLock() + AssertRLocked(m) + go func() { + m.RLock() + m.RUnlock() + }() + // Give the goroutine above a few moments to get started. + // The test will pass whether or not we win the race, + // but we want to run sometimes, to get the test coverage. + time.Sleep(10 * time.Millisecond) + AssertRLocked(m) + m.RUnlock() + + // Test correct handling of rlock with write waiter. + m.RLock() + AssertRLocked(m) + go func() { + m.Lock() + m.Unlock() + }() + // Give the goroutine above a few moments to get started. + // The test will pass whether or not we win the race, + // but we want to run sometimes, to get the test coverage. + time.Sleep(10 * time.Millisecond) + AssertRLocked(m) + m.RUnlock() + + // Test correct handling of rlock with other rlocks. + // This is a bit racy, but losing the race hurts nothing, + // and winning the race means correct test coverage. + m.RLock() + AssertRLocked(m) + go func() { + m.RLock() + time.Sleep(10 * time.Millisecond) + m.RUnlock() + }() + time.Sleep(5 * time.Millisecond) + AssertRLocked(m) + m.RUnlock() +} diff --git a/syncs/shardedmap.go b/syncs/shardedmap.go index 906de3ade..12edf5bfc 100644 --- a/syncs/shardedmap.go +++ b/syncs/shardedmap.go @@ -1,138 +1,138 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package syncs - -import ( - "sync" - - "golang.org/x/sys/cpu" -) - -// ShardedMap is a synchronized map[K]V, internally sharded by a user-defined -// K-sharding function. -// -// The zero value is not safe for use; use NewShardedMap. -type ShardedMap[K comparable, V any] struct { - shardFunc func(K) int - shards []mapShard[K, V] -} - -type mapShard[K comparable, V any] struct { - mu sync.Mutex - m map[K]V - _ cpu.CacheLinePad // avoid false sharing of neighboring shards' mutexes -} - -// NewShardedMap returns a new ShardedMap with the given number of shards and -// sharding function. -// -// The shard func must return a integer in the range [0, shards) purely -// deterministically based on the provided K. -func NewShardedMap[K comparable, V any](shards int, shard func(K) int) *ShardedMap[K, V] { - m := &ShardedMap[K, V]{ - shardFunc: shard, - shards: make([]mapShard[K, V], shards), - } - for i := range m.shards { - m.shards[i].m = make(map[K]V) - } - return m -} - -func (m *ShardedMap[K, V]) shard(key K) *mapShard[K, V] { - return &m.shards[m.shardFunc(key)] -} - -// GetOk returns m[key] and whether it was present. -func (m *ShardedMap[K, V]) GetOk(key K) (value V, ok bool) { - shard := m.shard(key) - shard.mu.Lock() - defer shard.mu.Unlock() - value, ok = shard.m[key] - return -} - -// Get returns m[key] or the zero value of V if key is not present. -func (m *ShardedMap[K, V]) Get(key K) (value V) { - value, _ = m.GetOk(key) - return -} - -// Mutate atomically mutates m[k] by calling mutator. -// -// The mutator function is called with the old value (or its zero value) and -// whether it existed in the map and it returns the new value and whether it -// should be set in the map (true) or deleted from the map (false). -// -// It returns the change in size of the map as a result of the mutation, one of -// -1 (delete), 0 (change), or 1 (addition). -func (m *ShardedMap[K, V]) Mutate(key K, mutator func(oldValue V, oldValueExisted bool) (newValue V, keep bool)) (sizeDelta int) { - shard := m.shard(key) - shard.mu.Lock() - defer shard.mu.Unlock() - oldV, oldOK := shard.m[key] - newV, newOK := mutator(oldV, oldOK) - if newOK { - shard.m[key] = newV - if oldOK { - return 0 - } - return 1 - } - delete(shard.m, key) - if oldOK { - return -1 - } - return 0 -} - -// Set sets m[key] = value. -// -// present in m). -func (m *ShardedMap[K, V]) Set(key K, value V) (grew bool) { - shard := m.shard(key) - shard.mu.Lock() - defer shard.mu.Unlock() - s0 := len(shard.m) - shard.m[key] = value - return len(shard.m) > s0 -} - -// Delete removes key from m. -// -// It reports whether the map size shrunk (that is, whether key was present in -// the map). -func (m *ShardedMap[K, V]) Delete(key K) (shrunk bool) { - shard := m.shard(key) - shard.mu.Lock() - defer shard.mu.Unlock() - s0 := len(shard.m) - delete(shard.m, key) - return len(shard.m) < s0 -} - -// Contains reports whether m contains key. -func (m *ShardedMap[K, V]) Contains(key K) bool { - shard := m.shard(key) - shard.mu.Lock() - defer shard.mu.Unlock() - _, ok := shard.m[key] - return ok -} - -// Len returns the number of elements in m. -// -// It does so by locking shards one at a time, so it's not particularly cheap, -// nor does it give a consistent snapshot of the map. It's mostly intended for -// metrics or testing. -func (m *ShardedMap[K, V]) Len() int { - n := 0 - for i := range m.shards { - shard := &m.shards[i] - shard.mu.Lock() - n += len(shard.m) - shard.mu.Unlock() - } - return n -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package syncs + +import ( + "sync" + + "golang.org/x/sys/cpu" +) + +// ShardedMap is a synchronized map[K]V, internally sharded by a user-defined +// K-sharding function. +// +// The zero value is not safe for use; use NewShardedMap. +type ShardedMap[K comparable, V any] struct { + shardFunc func(K) int + shards []mapShard[K, V] +} + +type mapShard[K comparable, V any] struct { + mu sync.Mutex + m map[K]V + _ cpu.CacheLinePad // avoid false sharing of neighboring shards' mutexes +} + +// NewShardedMap returns a new ShardedMap with the given number of shards and +// sharding function. +// +// The shard func must return a integer in the range [0, shards) purely +// deterministically based on the provided K. +func NewShardedMap[K comparable, V any](shards int, shard func(K) int) *ShardedMap[K, V] { + m := &ShardedMap[K, V]{ + shardFunc: shard, + shards: make([]mapShard[K, V], shards), + } + for i := range m.shards { + m.shards[i].m = make(map[K]V) + } + return m +} + +func (m *ShardedMap[K, V]) shard(key K) *mapShard[K, V] { + return &m.shards[m.shardFunc(key)] +} + +// GetOk returns m[key] and whether it was present. +func (m *ShardedMap[K, V]) GetOk(key K) (value V, ok bool) { + shard := m.shard(key) + shard.mu.Lock() + defer shard.mu.Unlock() + value, ok = shard.m[key] + return +} + +// Get returns m[key] or the zero value of V if key is not present. +func (m *ShardedMap[K, V]) Get(key K) (value V) { + value, _ = m.GetOk(key) + return +} + +// Mutate atomically mutates m[k] by calling mutator. +// +// The mutator function is called with the old value (or its zero value) and +// whether it existed in the map and it returns the new value and whether it +// should be set in the map (true) or deleted from the map (false). +// +// It returns the change in size of the map as a result of the mutation, one of +// -1 (delete), 0 (change), or 1 (addition). +func (m *ShardedMap[K, V]) Mutate(key K, mutator func(oldValue V, oldValueExisted bool) (newValue V, keep bool)) (sizeDelta int) { + shard := m.shard(key) + shard.mu.Lock() + defer shard.mu.Unlock() + oldV, oldOK := shard.m[key] + newV, newOK := mutator(oldV, oldOK) + if newOK { + shard.m[key] = newV + if oldOK { + return 0 + } + return 1 + } + delete(shard.m, key) + if oldOK { + return -1 + } + return 0 +} + +// Set sets m[key] = value. +// +// present in m). +func (m *ShardedMap[K, V]) Set(key K, value V) (grew bool) { + shard := m.shard(key) + shard.mu.Lock() + defer shard.mu.Unlock() + s0 := len(shard.m) + shard.m[key] = value + return len(shard.m) > s0 +} + +// Delete removes key from m. +// +// It reports whether the map size shrunk (that is, whether key was present in +// the map). +func (m *ShardedMap[K, V]) Delete(key K) (shrunk bool) { + shard := m.shard(key) + shard.mu.Lock() + defer shard.mu.Unlock() + s0 := len(shard.m) + delete(shard.m, key) + return len(shard.m) < s0 +} + +// Contains reports whether m contains key. +func (m *ShardedMap[K, V]) Contains(key K) bool { + shard := m.shard(key) + shard.mu.Lock() + defer shard.mu.Unlock() + _, ok := shard.m[key] + return ok +} + +// Len returns the number of elements in m. +// +// It does so by locking shards one at a time, so it's not particularly cheap, +// nor does it give a consistent snapshot of the map. It's mostly intended for +// metrics or testing. +func (m *ShardedMap[K, V]) Len() int { + n := 0 + for i := range m.shards { + shard := &m.shards[i] + shard.mu.Lock() + n += len(shard.m) + shard.mu.Unlock() + } + return n +} diff --git a/syncs/shardedmap_test.go b/syncs/shardedmap_test.go index 170201c0a..993ffdff8 100644 --- a/syncs/shardedmap_test.go +++ b/syncs/shardedmap_test.go @@ -1,81 +1,81 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package syncs - -import "testing" - -func TestShardedMap(t *testing.T) { - m := NewShardedMap[int, string](16, func(i int) int { return i % 16 }) - - if m.Contains(1) { - t.Errorf("got contains; want !contains") - } - if !m.Set(1, "one") { - t.Errorf("got !set; want set") - } - if m.Set(1, "one") { - t.Errorf("got set; want !set") - } - if !m.Contains(1) { - t.Errorf("got !contains; want contains") - } - if g, w := m.Get(1), "one"; g != w { - t.Errorf("got %q; want %q", g, w) - } - if _, ok := m.GetOk(1); !ok { - t.Errorf("got ok; want !ok") - } - if _, ok := m.GetOk(2); ok { - t.Errorf("got ok; want !ok") - } - if g, w := m.Len(), 1; g != w { - t.Errorf("got Len %v; want %v", g, w) - } - if m.Delete(2) { - t.Errorf("got deleted; want !deleted") - } - if !m.Delete(1) { - t.Errorf("got !deleted; want deleted") - } - if g, w := m.Len(), 0; g != w { - t.Errorf("got Len %v; want %v", g, w) - } - - // Mutation adding an entry. - if v := m.Mutate(1, func(was string, ok bool) (string, bool) { - if ok { - t.Fatal("was okay") - } - return "ONE", true - }); v != 1 { - t.Errorf("Mutate = %v; want 1", v) - } - if g, w := m.Get(1), "ONE"; g != w { - t.Errorf("got %q; want %q", g, w) - } - // Mutation changing an entry. - if v := m.Mutate(1, func(was string, ok bool) (string, bool) { - if !ok { - t.Fatal("wasn't okay") - } - return was + "-" + was, true - }); v != 0 { - t.Errorf("Mutate = %v; want 0", v) - } - if g, w := m.Get(1), "ONE-ONE"; g != w { - t.Errorf("got %q; want %q", g, w) - } - // Mutation removing an entry. - if v := m.Mutate(1, func(was string, ok bool) (string, bool) { - if !ok { - t.Fatal("wasn't okay") - } - return "", false - }); v != -1 { - t.Errorf("Mutate = %v; want -1", v) - } - if g, w := m.Get(1), ""; g != w { - t.Errorf("got %q; want %q", g, w) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package syncs + +import "testing" + +func TestShardedMap(t *testing.T) { + m := NewShardedMap[int, string](16, func(i int) int { return i % 16 }) + + if m.Contains(1) { + t.Errorf("got contains; want !contains") + } + if !m.Set(1, "one") { + t.Errorf("got !set; want set") + } + if m.Set(1, "one") { + t.Errorf("got set; want !set") + } + if !m.Contains(1) { + t.Errorf("got !contains; want contains") + } + if g, w := m.Get(1), "one"; g != w { + t.Errorf("got %q; want %q", g, w) + } + if _, ok := m.GetOk(1); !ok { + t.Errorf("got ok; want !ok") + } + if _, ok := m.GetOk(2); ok { + t.Errorf("got ok; want !ok") + } + if g, w := m.Len(), 1; g != w { + t.Errorf("got Len %v; want %v", g, w) + } + if m.Delete(2) { + t.Errorf("got deleted; want !deleted") + } + if !m.Delete(1) { + t.Errorf("got !deleted; want deleted") + } + if g, w := m.Len(), 0; g != w { + t.Errorf("got Len %v; want %v", g, w) + } + + // Mutation adding an entry. + if v := m.Mutate(1, func(was string, ok bool) (string, bool) { + if ok { + t.Fatal("was okay") + } + return "ONE", true + }); v != 1 { + t.Errorf("Mutate = %v; want 1", v) + } + if g, w := m.Get(1), "ONE"; g != w { + t.Errorf("got %q; want %q", g, w) + } + // Mutation changing an entry. + if v := m.Mutate(1, func(was string, ok bool) (string, bool) { + if !ok { + t.Fatal("wasn't okay") + } + return was + "-" + was, true + }); v != 0 { + t.Errorf("Mutate = %v; want 0", v) + } + if g, w := m.Get(1), "ONE-ONE"; g != w { + t.Errorf("got %q; want %q", g, w) + } + // Mutation removing an entry. + if v := m.Mutate(1, func(was string, ok bool) (string, bool) { + if !ok { + t.Fatal("wasn't okay") + } + return "", false + }); v != -1 { + t.Errorf("Mutate = %v; want -1", v) + } + if g, w := m.Get(1), ""; g != w { + t.Errorf("got %q; want %q", g, w) + } +} diff --git a/tailcfg/proto_port_range.go b/tailcfg/proto_port_range.go index 0bb7e388e..f65c58804 100644 --- a/tailcfg/proto_port_range.go +++ b/tailcfg/proto_port_range.go @@ -1,187 +1,187 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tailcfg - -import ( - "errors" - "fmt" - "strconv" - "strings" - - "tailscale.com/types/ipproto" - "tailscale.com/util/vizerror" -) - -var ( - errEmptyProtocol = errors.New("empty protocol") - errEmptyString = errors.New("empty string") -) - -// ProtoPortRange is used to encode "proto:port" format. -// The following formats are supported: -// -// "*" allows all TCP, UDP and ICMP traffic on all ports. -// "" allows all TCP, UDP and ICMP traffic on the specified ports. -// "proto:*" allows traffic of the specified proto on all ports. -// "proto:" allows traffic of the specified proto on the specified port. -// -// Ports are either a single port number or a range of ports (e.g. "80-90"). -// String named protocols support names that ipproto.Proto accepts. -type ProtoPortRange struct { - // Proto is the IP protocol number. - // If Proto is 0, it means TCP+UDP+ICMP(4+6). - Proto int - Ports PortRange -} - -// UnmarshalText implements the encoding.TextUnmarshaler interface. See -// ProtoPortRange for the format. -func (ppr *ProtoPortRange) UnmarshalText(text []byte) error { - ppr2, err := parseProtoPortRange(string(text)) - if err != nil { - return err - } - *ppr = *ppr2 - return nil -} - -// MarshalText implements the encoding.TextMarshaler interface. See -// ProtoPortRange for the format. -func (ppr *ProtoPortRange) MarshalText() ([]byte, error) { - if ppr.Proto == 0 && ppr.Ports == (PortRange{}) { - return []byte{}, nil - } - return []byte(ppr.String()), nil -} - -// String implements the stringer interface. See ProtoPortRange for the -// format. -func (ppr ProtoPortRange) String() string { - if ppr.Proto == 0 { - if ppr.Ports == PortRangeAny { - return "*" - } - } - var buf strings.Builder - if ppr.Proto != 0 { - // Proto.MarshalText is infallible. - text, _ := ipproto.Proto(ppr.Proto).MarshalText() - buf.Write(text) - buf.Write([]byte(":")) - } - pr := ppr.Ports - if pr.First == pr.Last { - fmt.Fprintf(&buf, "%d", pr.First) - } else if pr == PortRangeAny { - buf.WriteByte('*') - } else { - fmt.Fprintf(&buf, "%d-%d", pr.First, pr.Last) - } - return buf.String() -} - -// ParseProtoPortRanges parses a slice of IP port range fields. -func ParseProtoPortRanges(ips []string) ([]ProtoPortRange, error) { - var out []ProtoPortRange - for _, p := range ips { - ppr, err := parseProtoPortRange(p) - if err != nil { - return nil, err - } - out = append(out, *ppr) - } - return out, nil -} - -func parseProtoPortRange(ipProtoPort string) (*ProtoPortRange, error) { - if ipProtoPort == "" { - return nil, errEmptyString - } - if ipProtoPort == "*" { - return &ProtoPortRange{Ports: PortRangeAny}, nil - } - if !strings.Contains(ipProtoPort, ":") { - ipProtoPort = "*:" + ipProtoPort - } - protoStr, portRange, err := parseHostPortRange(ipProtoPort) - if err != nil { - return nil, err - } - if protoStr == "" { - return nil, errEmptyProtocol - } - - ppr := &ProtoPortRange{ - Ports: portRange, - } - if protoStr == "*" { - return ppr, nil - } - var ipProto ipproto.Proto - if err := ipProto.UnmarshalText([]byte(protoStr)); err != nil { - return nil, err - } - ppr.Proto = int(ipProto) - return ppr, nil -} - -// parseHostPortRange parses hostport as HOST:PORTS where HOST is -// returned unchanged and PORTS is is either "*" or PORTLOW-PORTHIGH ranges. -func parseHostPortRange(hostport string) (host string, ports PortRange, err error) { - hostport = strings.ToLower(hostport) - colon := strings.LastIndexByte(hostport, ':') - if colon < 0 { - return "", ports, vizerror.New("hostport must contain a colon (\":\")") - } - host = hostport[:colon] - portlist := hostport[colon+1:] - - if strings.Contains(host, ",") { - return "", ports, vizerror.New("host cannot contain a comma (\",\")") - } - - if portlist == "*" { - // Special case: permit hostname:* as a port wildcard. - return host, PortRangeAny, nil - } - - if len(portlist) == 0 { - return "", ports, vizerror.Errorf("invalid port list: %#v", portlist) - } - - if strings.Count(portlist, "-") > 1 { - return "", ports, vizerror.Errorf("port range %#v: too many dashes(-)", portlist) - } - - firstStr, lastStr, isRange := strings.Cut(portlist, "-") - - var first, last uint64 - first, err = strconv.ParseUint(firstStr, 10, 16) - if err != nil { - return "", ports, vizerror.Errorf("port range %#v: invalid first integer", portlist) - } - - if isRange { - last, err = strconv.ParseUint(lastStr, 10, 16) - if err != nil { - return "", ports, vizerror.Errorf("port range %#v: invalid last integer", portlist) - } - } else { - last = first - } - - if first == 0 { - return "", ports, vizerror.Errorf("port range %#v: first port must be >0, or use '*' for wildcard", portlist) - } - - if first > last { - return "", ports, vizerror.Errorf("port range %#v: first port must be >= last port", portlist) - } - - return host, newPortRange(uint16(first), uint16(last)), nil -} - -func newPortRange(first, last uint16) PortRange { - return PortRange{First: first, Last: last} -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tailcfg + +import ( + "errors" + "fmt" + "strconv" + "strings" + + "tailscale.com/types/ipproto" + "tailscale.com/util/vizerror" +) + +var ( + errEmptyProtocol = errors.New("empty protocol") + errEmptyString = errors.New("empty string") +) + +// ProtoPortRange is used to encode "proto:port" format. +// The following formats are supported: +// +// "*" allows all TCP, UDP and ICMP traffic on all ports. +// "" allows all TCP, UDP and ICMP traffic on the specified ports. +// "proto:*" allows traffic of the specified proto on all ports. +// "proto:" allows traffic of the specified proto on the specified port. +// +// Ports are either a single port number or a range of ports (e.g. "80-90"). +// String named protocols support names that ipproto.Proto accepts. +type ProtoPortRange struct { + // Proto is the IP protocol number. + // If Proto is 0, it means TCP+UDP+ICMP(4+6). + Proto int + Ports PortRange +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. See +// ProtoPortRange for the format. +func (ppr *ProtoPortRange) UnmarshalText(text []byte) error { + ppr2, err := parseProtoPortRange(string(text)) + if err != nil { + return err + } + *ppr = *ppr2 + return nil +} + +// MarshalText implements the encoding.TextMarshaler interface. See +// ProtoPortRange for the format. +func (ppr *ProtoPortRange) MarshalText() ([]byte, error) { + if ppr.Proto == 0 && ppr.Ports == (PortRange{}) { + return []byte{}, nil + } + return []byte(ppr.String()), nil +} + +// String implements the stringer interface. See ProtoPortRange for the +// format. +func (ppr ProtoPortRange) String() string { + if ppr.Proto == 0 { + if ppr.Ports == PortRangeAny { + return "*" + } + } + var buf strings.Builder + if ppr.Proto != 0 { + // Proto.MarshalText is infallible. + text, _ := ipproto.Proto(ppr.Proto).MarshalText() + buf.Write(text) + buf.Write([]byte(":")) + } + pr := ppr.Ports + if pr.First == pr.Last { + fmt.Fprintf(&buf, "%d", pr.First) + } else if pr == PortRangeAny { + buf.WriteByte('*') + } else { + fmt.Fprintf(&buf, "%d-%d", pr.First, pr.Last) + } + return buf.String() +} + +// ParseProtoPortRanges parses a slice of IP port range fields. +func ParseProtoPortRanges(ips []string) ([]ProtoPortRange, error) { + var out []ProtoPortRange + for _, p := range ips { + ppr, err := parseProtoPortRange(p) + if err != nil { + return nil, err + } + out = append(out, *ppr) + } + return out, nil +} + +func parseProtoPortRange(ipProtoPort string) (*ProtoPortRange, error) { + if ipProtoPort == "" { + return nil, errEmptyString + } + if ipProtoPort == "*" { + return &ProtoPortRange{Ports: PortRangeAny}, nil + } + if !strings.Contains(ipProtoPort, ":") { + ipProtoPort = "*:" + ipProtoPort + } + protoStr, portRange, err := parseHostPortRange(ipProtoPort) + if err != nil { + return nil, err + } + if protoStr == "" { + return nil, errEmptyProtocol + } + + ppr := &ProtoPortRange{ + Ports: portRange, + } + if protoStr == "*" { + return ppr, nil + } + var ipProto ipproto.Proto + if err := ipProto.UnmarshalText([]byte(protoStr)); err != nil { + return nil, err + } + ppr.Proto = int(ipProto) + return ppr, nil +} + +// parseHostPortRange parses hostport as HOST:PORTS where HOST is +// returned unchanged and PORTS is is either "*" or PORTLOW-PORTHIGH ranges. +func parseHostPortRange(hostport string) (host string, ports PortRange, err error) { + hostport = strings.ToLower(hostport) + colon := strings.LastIndexByte(hostport, ':') + if colon < 0 { + return "", ports, vizerror.New("hostport must contain a colon (\":\")") + } + host = hostport[:colon] + portlist := hostport[colon+1:] + + if strings.Contains(host, ",") { + return "", ports, vizerror.New("host cannot contain a comma (\",\")") + } + + if portlist == "*" { + // Special case: permit hostname:* as a port wildcard. + return host, PortRangeAny, nil + } + + if len(portlist) == 0 { + return "", ports, vizerror.Errorf("invalid port list: %#v", portlist) + } + + if strings.Count(portlist, "-") > 1 { + return "", ports, vizerror.Errorf("port range %#v: too many dashes(-)", portlist) + } + + firstStr, lastStr, isRange := strings.Cut(portlist, "-") + + var first, last uint64 + first, err = strconv.ParseUint(firstStr, 10, 16) + if err != nil { + return "", ports, vizerror.Errorf("port range %#v: invalid first integer", portlist) + } + + if isRange { + last, err = strconv.ParseUint(lastStr, 10, 16) + if err != nil { + return "", ports, vizerror.Errorf("port range %#v: invalid last integer", portlist) + } + } else { + last = first + } + + if first == 0 { + return "", ports, vizerror.Errorf("port range %#v: first port must be >0, or use '*' for wildcard", portlist) + } + + if first > last { + return "", ports, vizerror.Errorf("port range %#v: first port must be >= last port", portlist) + } + + return host, newPortRange(uint16(first), uint16(last)), nil +} + +func newPortRange(first, last uint16) PortRange { + return PortRange{First: first, Last: last} +} diff --git a/tailcfg/proto_port_range_test.go b/tailcfg/proto_port_range_test.go index 31b282641..59ccc9be4 100644 --- a/tailcfg/proto_port_range_test.go +++ b/tailcfg/proto_port_range_test.go @@ -1,131 +1,131 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tailcfg - -import ( - "encoding" - "testing" - - "tailscale.com/types/ipproto" - "tailscale.com/util/vizerror" -) - -var _ encoding.TextUnmarshaler = (*ProtoPortRange)(nil) - -func TestProtoPortRangeParsing(t *testing.T) { - pr := func(s, e uint16) PortRange { - return PortRange{First: s, Last: e} - } - tests := []struct { - in string - out ProtoPortRange - err error - }{ - {in: "tcp:80", out: ProtoPortRange{Proto: int(ipproto.TCP), Ports: pr(80, 80)}}, - {in: "80", out: ProtoPortRange{Ports: pr(80, 80)}}, - {in: "*", out: ProtoPortRange{Ports: PortRangeAny}}, - {in: "*:*", out: ProtoPortRange{Ports: PortRangeAny}}, - {in: "tcp:*", out: ProtoPortRange{Proto: int(ipproto.TCP), Ports: PortRangeAny}}, - { - in: "tcp:", - err: vizerror.Errorf("invalid port list: %#v", ""), - }, - { - in: ":80", - err: errEmptyProtocol, - }, - { - in: "", - err: errEmptyString, - }, - } - - for _, tc := range tests { - t.Run(tc.in, func(t *testing.T) { - var ppr ProtoPortRange - err := ppr.UnmarshalText([]byte(tc.in)) - if tc.err != err { - if err == nil || tc.err.Error() != err.Error() { - t.Fatalf("want err=%v, got %v", tc.err, err) - } - } - if ppr != tc.out { - t.Fatalf("got %v; want %v", ppr, tc.out) - } - }) - } -} - -func TestProtoPortRangeString(t *testing.T) { - tests := []struct { - input ProtoPortRange - want string - }{ - {ProtoPortRange{}, "0"}, - - // Zero protocol. - {ProtoPortRange{Ports: PortRangeAny}, "*"}, - {ProtoPortRange{Ports: PortRange{23, 23}}, "23"}, - {ProtoPortRange{Ports: PortRange{80, 120}}, "80-120"}, - - // Non-zero unnamed protocol. - {ProtoPortRange{Proto: 100, Ports: PortRange{80, 80}}, "100:80"}, - {ProtoPortRange{Proto: 200, Ports: PortRange{101, 105}}, "200:101-105"}, - - // Non-zero named protocol. - {ProtoPortRange{Proto: 1, Ports: PortRangeAny}, "icmp:*"}, - {ProtoPortRange{Proto: 2, Ports: PortRangeAny}, "igmp:*"}, - {ProtoPortRange{Proto: 6, Ports: PortRange{10, 13}}, "tcp:10-13"}, - {ProtoPortRange{Proto: 17, Ports: PortRangeAny}, "udp:*"}, - {ProtoPortRange{Proto: 0x84, Ports: PortRange{999, 999}}, "sctp:999"}, - {ProtoPortRange{Proto: 0x3a, Ports: PortRangeAny}, "ipv6-icmp:*"}, - {ProtoPortRange{Proto: 0x21, Ports: PortRangeAny}, "dccp:*"}, - {ProtoPortRange{Proto: 0x2f, Ports: PortRangeAny}, "gre:*"}, - } - for _, tc := range tests { - if got := tc.input.String(); got != tc.want { - t.Errorf("String for %v: got %q, want %q", tc.input, got, tc.want) - } - } -} - -func TestProtoPortRangeRoundTrip(t *testing.T) { - tests := []struct { - input ProtoPortRange - text string - }{ - {ProtoPortRange{Ports: PortRangeAny}, "*"}, - {ProtoPortRange{Ports: PortRange{23, 23}}, "23"}, - {ProtoPortRange{Ports: PortRange{80, 120}}, "80-120"}, - {ProtoPortRange{Proto: 100, Ports: PortRange{80, 80}}, "100:80"}, - {ProtoPortRange{Proto: 200, Ports: PortRange{101, 105}}, "200:101-105"}, - {ProtoPortRange{Proto: 1, Ports: PortRangeAny}, "icmp:*"}, - {ProtoPortRange{Proto: 2, Ports: PortRangeAny}, "igmp:*"}, - {ProtoPortRange{Proto: 6, Ports: PortRange{10, 13}}, "tcp:10-13"}, - {ProtoPortRange{Proto: 17, Ports: PortRangeAny}, "udp:*"}, - {ProtoPortRange{Proto: 0x84, Ports: PortRange{999, 999}}, "sctp:999"}, - {ProtoPortRange{Proto: 0x3a, Ports: PortRangeAny}, "ipv6-icmp:*"}, - {ProtoPortRange{Proto: 0x21, Ports: PortRangeAny}, "dccp:*"}, - {ProtoPortRange{Proto: 0x2f, Ports: PortRangeAny}, "gre:*"}, - } - - for _, tc := range tests { - out, err := tc.input.MarshalText() - if err != nil { - t.Errorf("MarshalText for %v: %v", tc.input, err) - continue - } - if got := string(out); got != tc.text { - t.Errorf("MarshalText for %#v: got %q, want %q", tc.input, got, tc.text) - } - var ppr ProtoPortRange - if err := ppr.UnmarshalText(out); err != nil { - t.Errorf("UnmarshalText for %q: err=%v", tc.text, err) - continue - } - if ppr != tc.input { - t.Errorf("round trip error for %q: got %v, want %#v", tc.text, ppr, tc.input) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tailcfg + +import ( + "encoding" + "testing" + + "tailscale.com/types/ipproto" + "tailscale.com/util/vizerror" +) + +var _ encoding.TextUnmarshaler = (*ProtoPortRange)(nil) + +func TestProtoPortRangeParsing(t *testing.T) { + pr := func(s, e uint16) PortRange { + return PortRange{First: s, Last: e} + } + tests := []struct { + in string + out ProtoPortRange + err error + }{ + {in: "tcp:80", out: ProtoPortRange{Proto: int(ipproto.TCP), Ports: pr(80, 80)}}, + {in: "80", out: ProtoPortRange{Ports: pr(80, 80)}}, + {in: "*", out: ProtoPortRange{Ports: PortRangeAny}}, + {in: "*:*", out: ProtoPortRange{Ports: PortRangeAny}}, + {in: "tcp:*", out: ProtoPortRange{Proto: int(ipproto.TCP), Ports: PortRangeAny}}, + { + in: "tcp:", + err: vizerror.Errorf("invalid port list: %#v", ""), + }, + { + in: ":80", + err: errEmptyProtocol, + }, + { + in: "", + err: errEmptyString, + }, + } + + for _, tc := range tests { + t.Run(tc.in, func(t *testing.T) { + var ppr ProtoPortRange + err := ppr.UnmarshalText([]byte(tc.in)) + if tc.err != err { + if err == nil || tc.err.Error() != err.Error() { + t.Fatalf("want err=%v, got %v", tc.err, err) + } + } + if ppr != tc.out { + t.Fatalf("got %v; want %v", ppr, tc.out) + } + }) + } +} + +func TestProtoPortRangeString(t *testing.T) { + tests := []struct { + input ProtoPortRange + want string + }{ + {ProtoPortRange{}, "0"}, + + // Zero protocol. + {ProtoPortRange{Ports: PortRangeAny}, "*"}, + {ProtoPortRange{Ports: PortRange{23, 23}}, "23"}, + {ProtoPortRange{Ports: PortRange{80, 120}}, "80-120"}, + + // Non-zero unnamed protocol. + {ProtoPortRange{Proto: 100, Ports: PortRange{80, 80}}, "100:80"}, + {ProtoPortRange{Proto: 200, Ports: PortRange{101, 105}}, "200:101-105"}, + + // Non-zero named protocol. + {ProtoPortRange{Proto: 1, Ports: PortRangeAny}, "icmp:*"}, + {ProtoPortRange{Proto: 2, Ports: PortRangeAny}, "igmp:*"}, + {ProtoPortRange{Proto: 6, Ports: PortRange{10, 13}}, "tcp:10-13"}, + {ProtoPortRange{Proto: 17, Ports: PortRangeAny}, "udp:*"}, + {ProtoPortRange{Proto: 0x84, Ports: PortRange{999, 999}}, "sctp:999"}, + {ProtoPortRange{Proto: 0x3a, Ports: PortRangeAny}, "ipv6-icmp:*"}, + {ProtoPortRange{Proto: 0x21, Ports: PortRangeAny}, "dccp:*"}, + {ProtoPortRange{Proto: 0x2f, Ports: PortRangeAny}, "gre:*"}, + } + for _, tc := range tests { + if got := tc.input.String(); got != tc.want { + t.Errorf("String for %v: got %q, want %q", tc.input, got, tc.want) + } + } +} + +func TestProtoPortRangeRoundTrip(t *testing.T) { + tests := []struct { + input ProtoPortRange + text string + }{ + {ProtoPortRange{Ports: PortRangeAny}, "*"}, + {ProtoPortRange{Ports: PortRange{23, 23}}, "23"}, + {ProtoPortRange{Ports: PortRange{80, 120}}, "80-120"}, + {ProtoPortRange{Proto: 100, Ports: PortRange{80, 80}}, "100:80"}, + {ProtoPortRange{Proto: 200, Ports: PortRange{101, 105}}, "200:101-105"}, + {ProtoPortRange{Proto: 1, Ports: PortRangeAny}, "icmp:*"}, + {ProtoPortRange{Proto: 2, Ports: PortRangeAny}, "igmp:*"}, + {ProtoPortRange{Proto: 6, Ports: PortRange{10, 13}}, "tcp:10-13"}, + {ProtoPortRange{Proto: 17, Ports: PortRangeAny}, "udp:*"}, + {ProtoPortRange{Proto: 0x84, Ports: PortRange{999, 999}}, "sctp:999"}, + {ProtoPortRange{Proto: 0x3a, Ports: PortRangeAny}, "ipv6-icmp:*"}, + {ProtoPortRange{Proto: 0x21, Ports: PortRangeAny}, "dccp:*"}, + {ProtoPortRange{Proto: 0x2f, Ports: PortRangeAny}, "gre:*"}, + } + + for _, tc := range tests { + out, err := tc.input.MarshalText() + if err != nil { + t.Errorf("MarshalText for %v: %v", tc.input, err) + continue + } + if got := string(out); got != tc.text { + t.Errorf("MarshalText for %#v: got %q, want %q", tc.input, got, tc.text) + } + var ppr ProtoPortRange + if err := ppr.UnmarshalText(out); err != nil { + t.Errorf("UnmarshalText for %q: err=%v", tc.text, err) + continue + } + if ppr != tc.input { + t.Errorf("round trip error for %q: got %v, want %#v", tc.text, ppr, tc.input) + } + } +} diff --git a/tailcfg/tka.go b/tailcfg/tka.go index ca7e6be76..97fdcc0db 100644 --- a/tailcfg/tka.go +++ b/tailcfg/tka.go @@ -1,264 +1,264 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tailcfg - -import ( - "tailscale.com/types/key" - "tailscale.com/types/tkatype" -) - -// TKAInitBeginRequest submits a genesis AUM to seed the creation of the -// tailnet's key authority. -type TKAInitBeginRequest struct { - // Version is the client's capabilities. - Version CapabilityVersion - - // NodeKey is the client's current node key. - NodeKey key.NodePublic - - // GenesisAUM is the initial (genesis) AUM that the node generated - // to bootstrap tailnet key authority state. - GenesisAUM tkatype.MarshaledAUM -} - -// TKASignInfo describes information about an existing node that needs -// to be signed into a node-key signature. -type TKASignInfo struct { - // NodeID is the ID of the node which needs a signature. It must - // correspond to NodePublic. - NodeID NodeID - // NodePublic is the node (Wireguard) public key which is being - // signed. - NodePublic key.NodePublic - - // RotationPubkey specifies the public key which may sign - // a NodeKeySignature (NKS), which rotates the node key. - // - // This is necessary so the node can rotate its node-key without - // talking to a node which holds a trusted network-lock key. - // It does this by nesting the original NKS in a 'rotation' NKS, - // which it then signs with the key corresponding to RotationPubkey. - // - // This field expects a raw ed25519 public key. - RotationPubkey []byte -} - -// TKAInitBeginResponse is the JSON response from a /tka/init/begin RPC. -// This structure describes node information which must be signed to -// complete initialization of the tailnets' key authority. -type TKAInitBeginResponse struct { - // NeedSignatures specify information about the nodes in your tailnet - // which need initial signatures to function once the tailnet key - // authority is in use. The generated signatures should then be - // submitted in a /tka/init/finish RPC. - NeedSignatures []TKASignInfo -} - -// TKAInitFinishRequest is the JSON request of a /tka/init/finish RPC. -// This RPC finalizes initialization of the tailnet key authority -// by submitting node-key signatures for all existing nodes. -type TKAInitFinishRequest struct { - // Version is the client's capabilities. - Version CapabilityVersion - - // NodeKey is the client's current node key. - NodeKey key.NodePublic - - // Signatures are serialized tka.NodeKeySignatures for all nodes - // in the tailnet. - Signatures map[NodeID]tkatype.MarshaledSignature - - // SupportDisablement is a disablement secret for Tailscale support. - // This is only generated if --gen-disablement-for-support is specified - // in an invocation to 'tailscale lock init'. - SupportDisablement []byte `json:",omitempty"` -} - -// TKAInitFinishResponse is the JSON response from a /tka/init/finish RPC. -// This schema describes the successful enablement of the tailnet's -// key authority. -type TKAInitFinishResponse struct { - // Nothing. (yet?) -} - -// TKAInfo encodes the control plane's view of tailnet key authority (TKA) -// state. This information is transmitted as part of the MapResponse. -type TKAInfo struct { - // Head describes the hash of the latest AUM applied to the authority. - // Head is encoded as tka.AUMHash.MarshalText. - // - // If the Head state differs to that known locally, the node should perform - // synchronization via a separate RPC. - Head string `json:",omitempty"` - - // Disabled indicates the control plane believes TKA should be disabled, - // and the node should reach out to fetch a disablement - // secret. If the disablement secret verifies, then the node should then - // disable TKA locally. - // This field exists to disambiguate a nil TKAInfo in a delta mapresponse - // from a nil TKAInfo indicating TKA should be disabled. - Disabled bool `json:",omitempty"` -} - -// TKABootstrapRequest is sent by a node to get information necessary for -// enabling or disabling the tailnet key authority. -type TKABootstrapRequest struct { - // Version is the client's capabilities. - Version CapabilityVersion - - // NodeKey is the client's current node key. - NodeKey key.NodePublic - - // Head represents the node's head AUMHash (tka.Authority.Head), if - // network lock is enabled. - Head string -} - -// TKABootstrapResponse encodes values necessary to enable or disable -// the tailnet key authority (TKA). -type TKABootstrapResponse struct { - // GenesisAUM returns the initial AUM necessary to initialize TKA. - GenesisAUM tkatype.MarshaledAUM `json:",omitempty"` - - // DisablementSecret encodes a secret necessary to disable TKA. - DisablementSecret []byte `json:",omitempty"` -} - -// TKASyncOfferRequest encodes a request to synchronize tailnet key authority -// state (TKA). Values of type tka.AUMHash are encoded as strings in their -// MarshalText form. -type TKASyncOfferRequest struct { - // Version is the client's capabilities. - Version CapabilityVersion - - // NodeKey is the client's current node key. - NodeKey key.NodePublic - - // Head represents the node's head AUMHash (tka.Authority.Head). This - // corresponds to tka.SyncOffer.Head. - Head string - // Ancestors represents a selection of ancestor AUMHash values ascending - // from the current head. This corresponds to tka.SyncOffer.Ancestors. - Ancestors []string -} - -// TKASyncOfferResponse encodes a response in synchronizing a node's -// tailnet key authority state. Values of type tka.AUMHash are encoded as -// strings in their MarshalText form. -type TKASyncOfferResponse struct { - // Head represents the control plane's head AUMHash (tka.Authority.Head). - // This corresponds to tka.SyncOffer.Head. - Head string - // Ancestors represents a selection of ancestor AUMHash values ascending - // from the control plane's head. This corresponds to - // tka.SyncOffer.Ancestors. - Ancestors []string - // MissingAUMs encodes AUMs that the control plane believes the node - // is missing. - MissingAUMs []tkatype.MarshaledAUM -} - -// TKASyncSendRequest encodes AUMs that a node believes the control plane -// is missing, and notifies control of its local TKA state (specifically -// the head hash). -type TKASyncSendRequest struct { - // Version is the client's capabilities. - Version CapabilityVersion - - // NodeKey is the client's current node key. - NodeKey key.NodePublic - - // Head represents the node's head AUMHash (tka.Authority.Head) after - // applying any AUMs from the sync-offer response. - // It is encoded as tka.AUMHash.MarshalText. - Head string - - // MissingAUMs encodes AUMs that the node believes the control plane - // is missing. - MissingAUMs []tkatype.MarshaledAUM - - // Interactive is true if additional error checking should be performed as - // the request is on behalf of an interactive operation (e.g., an - // administrator publishing new changes) as opposed to an automatic - // synchronization that may be reporting lost data. - Interactive bool -} - -// TKASyncSendResponse encodes the control plane's response to a node -// submitting AUMs during AUM synchronization. -type TKASyncSendResponse struct { - // Head represents the control plane's head AUMHash (tka.Authority.Head), - // after applying the missing AUMs. - Head string -} - -// TKADisableRequest disables network-lock across the tailnet using the -// provided disablement secret. -// -// This is the request schema for a /tka/disable noise RPC. -type TKADisableRequest struct { - // Version is the client's capabilities. - Version CapabilityVersion - - // NodeKey is the client's current node key. - NodeKey key.NodePublic - - // Head represents the node's head AUMHash (tka.Authority.Head). - // It is encoded as tka.AUMHash.MarshalText. - Head string - - // DisablementSecret encodes the secret necessary to disable TKA. - DisablementSecret []byte -} - -// TKADisableResponse is the JSON response from a /tka/disable RPC. -// This schema describes the successful disablement of the tailnet's -// key authority. -type TKADisableResponse struct { - // Nothing. (yet?) -} - -// TKASubmitSignatureRequest transmits a node-key signature to the control plane. -// -// This is the request schema for a /tka/sign noise RPC. -type TKASubmitSignatureRequest struct { - // Version is the client's capabilities. - Version CapabilityVersion - - // NodeKey is the client's current node key. The node-key which - // is being signed is embedded in Signature. - NodeKey key.NodePublic - - // Signature encodes the node-key signature being submitted. - Signature tkatype.MarshaledSignature -} - -// TKASubmitSignatureResponse is the JSON response from a /tka/sign RPC. -type TKASubmitSignatureResponse struct { - // Nothing. (yet?) -} - -// TKASignaturesUsingKeyRequest asks the control plane for -// all signatures which are signed by the provided keyID. -// -// This is the request schema for a /tka/affected-sigs RPC. -type TKASignaturesUsingKeyRequest struct { - // Version is the client's capabilities. - Version CapabilityVersion - - // NodeKey is the client's current node key. - NodeKey key.NodePublic - - // KeyID is the key we are querying using. - KeyID tkatype.KeyID -} - -// TKASignaturesUsingKeyResponse is the JSON response to -// a /tka/affected-sigs RPC. -// -// It enumerates all signatures which are signed by the -// queried keyID. -type TKASignaturesUsingKeyResponse struct { - Signatures []tkatype.MarshaledSignature -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tailcfg + +import ( + "tailscale.com/types/key" + "tailscale.com/types/tkatype" +) + +// TKAInitBeginRequest submits a genesis AUM to seed the creation of the +// tailnet's key authority. +type TKAInitBeginRequest struct { + // Version is the client's capabilities. + Version CapabilityVersion + + // NodeKey is the client's current node key. + NodeKey key.NodePublic + + // GenesisAUM is the initial (genesis) AUM that the node generated + // to bootstrap tailnet key authority state. + GenesisAUM tkatype.MarshaledAUM +} + +// TKASignInfo describes information about an existing node that needs +// to be signed into a node-key signature. +type TKASignInfo struct { + // NodeID is the ID of the node which needs a signature. It must + // correspond to NodePublic. + NodeID NodeID + // NodePublic is the node (Wireguard) public key which is being + // signed. + NodePublic key.NodePublic + + // RotationPubkey specifies the public key which may sign + // a NodeKeySignature (NKS), which rotates the node key. + // + // This is necessary so the node can rotate its node-key without + // talking to a node which holds a trusted network-lock key. + // It does this by nesting the original NKS in a 'rotation' NKS, + // which it then signs with the key corresponding to RotationPubkey. + // + // This field expects a raw ed25519 public key. + RotationPubkey []byte +} + +// TKAInitBeginResponse is the JSON response from a /tka/init/begin RPC. +// This structure describes node information which must be signed to +// complete initialization of the tailnets' key authority. +type TKAInitBeginResponse struct { + // NeedSignatures specify information about the nodes in your tailnet + // which need initial signatures to function once the tailnet key + // authority is in use. The generated signatures should then be + // submitted in a /tka/init/finish RPC. + NeedSignatures []TKASignInfo +} + +// TKAInitFinishRequest is the JSON request of a /tka/init/finish RPC. +// This RPC finalizes initialization of the tailnet key authority +// by submitting node-key signatures for all existing nodes. +type TKAInitFinishRequest struct { + // Version is the client's capabilities. + Version CapabilityVersion + + // NodeKey is the client's current node key. + NodeKey key.NodePublic + + // Signatures are serialized tka.NodeKeySignatures for all nodes + // in the tailnet. + Signatures map[NodeID]tkatype.MarshaledSignature + + // SupportDisablement is a disablement secret for Tailscale support. + // This is only generated if --gen-disablement-for-support is specified + // in an invocation to 'tailscale lock init'. + SupportDisablement []byte `json:",omitempty"` +} + +// TKAInitFinishResponse is the JSON response from a /tka/init/finish RPC. +// This schema describes the successful enablement of the tailnet's +// key authority. +type TKAInitFinishResponse struct { + // Nothing. (yet?) +} + +// TKAInfo encodes the control plane's view of tailnet key authority (TKA) +// state. This information is transmitted as part of the MapResponse. +type TKAInfo struct { + // Head describes the hash of the latest AUM applied to the authority. + // Head is encoded as tka.AUMHash.MarshalText. + // + // If the Head state differs to that known locally, the node should perform + // synchronization via a separate RPC. + Head string `json:",omitempty"` + + // Disabled indicates the control plane believes TKA should be disabled, + // and the node should reach out to fetch a disablement + // secret. If the disablement secret verifies, then the node should then + // disable TKA locally. + // This field exists to disambiguate a nil TKAInfo in a delta mapresponse + // from a nil TKAInfo indicating TKA should be disabled. + Disabled bool `json:",omitempty"` +} + +// TKABootstrapRequest is sent by a node to get information necessary for +// enabling or disabling the tailnet key authority. +type TKABootstrapRequest struct { + // Version is the client's capabilities. + Version CapabilityVersion + + // NodeKey is the client's current node key. + NodeKey key.NodePublic + + // Head represents the node's head AUMHash (tka.Authority.Head), if + // network lock is enabled. + Head string +} + +// TKABootstrapResponse encodes values necessary to enable or disable +// the tailnet key authority (TKA). +type TKABootstrapResponse struct { + // GenesisAUM returns the initial AUM necessary to initialize TKA. + GenesisAUM tkatype.MarshaledAUM `json:",omitempty"` + + // DisablementSecret encodes a secret necessary to disable TKA. + DisablementSecret []byte `json:",omitempty"` +} + +// TKASyncOfferRequest encodes a request to synchronize tailnet key authority +// state (TKA). Values of type tka.AUMHash are encoded as strings in their +// MarshalText form. +type TKASyncOfferRequest struct { + // Version is the client's capabilities. + Version CapabilityVersion + + // NodeKey is the client's current node key. + NodeKey key.NodePublic + + // Head represents the node's head AUMHash (tka.Authority.Head). This + // corresponds to tka.SyncOffer.Head. + Head string + // Ancestors represents a selection of ancestor AUMHash values ascending + // from the current head. This corresponds to tka.SyncOffer.Ancestors. + Ancestors []string +} + +// TKASyncOfferResponse encodes a response in synchronizing a node's +// tailnet key authority state. Values of type tka.AUMHash are encoded as +// strings in their MarshalText form. +type TKASyncOfferResponse struct { + // Head represents the control plane's head AUMHash (tka.Authority.Head). + // This corresponds to tka.SyncOffer.Head. + Head string + // Ancestors represents a selection of ancestor AUMHash values ascending + // from the control plane's head. This corresponds to + // tka.SyncOffer.Ancestors. + Ancestors []string + // MissingAUMs encodes AUMs that the control plane believes the node + // is missing. + MissingAUMs []tkatype.MarshaledAUM +} + +// TKASyncSendRequest encodes AUMs that a node believes the control plane +// is missing, and notifies control of its local TKA state (specifically +// the head hash). +type TKASyncSendRequest struct { + // Version is the client's capabilities. + Version CapabilityVersion + + // NodeKey is the client's current node key. + NodeKey key.NodePublic + + // Head represents the node's head AUMHash (tka.Authority.Head) after + // applying any AUMs from the sync-offer response. + // It is encoded as tka.AUMHash.MarshalText. + Head string + + // MissingAUMs encodes AUMs that the node believes the control plane + // is missing. + MissingAUMs []tkatype.MarshaledAUM + + // Interactive is true if additional error checking should be performed as + // the request is on behalf of an interactive operation (e.g., an + // administrator publishing new changes) as opposed to an automatic + // synchronization that may be reporting lost data. + Interactive bool +} + +// TKASyncSendResponse encodes the control plane's response to a node +// submitting AUMs during AUM synchronization. +type TKASyncSendResponse struct { + // Head represents the control plane's head AUMHash (tka.Authority.Head), + // after applying the missing AUMs. + Head string +} + +// TKADisableRequest disables network-lock across the tailnet using the +// provided disablement secret. +// +// This is the request schema for a /tka/disable noise RPC. +type TKADisableRequest struct { + // Version is the client's capabilities. + Version CapabilityVersion + + // NodeKey is the client's current node key. + NodeKey key.NodePublic + + // Head represents the node's head AUMHash (tka.Authority.Head). + // It is encoded as tka.AUMHash.MarshalText. + Head string + + // DisablementSecret encodes the secret necessary to disable TKA. + DisablementSecret []byte +} + +// TKADisableResponse is the JSON response from a /tka/disable RPC. +// This schema describes the successful disablement of the tailnet's +// key authority. +type TKADisableResponse struct { + // Nothing. (yet?) +} + +// TKASubmitSignatureRequest transmits a node-key signature to the control plane. +// +// This is the request schema for a /tka/sign noise RPC. +type TKASubmitSignatureRequest struct { + // Version is the client's capabilities. + Version CapabilityVersion + + // NodeKey is the client's current node key. The node-key which + // is being signed is embedded in Signature. + NodeKey key.NodePublic + + // Signature encodes the node-key signature being submitted. + Signature tkatype.MarshaledSignature +} + +// TKASubmitSignatureResponse is the JSON response from a /tka/sign RPC. +type TKASubmitSignatureResponse struct { + // Nothing. (yet?) +} + +// TKASignaturesUsingKeyRequest asks the control plane for +// all signatures which are signed by the provided keyID. +// +// This is the request schema for a /tka/affected-sigs RPC. +type TKASignaturesUsingKeyRequest struct { + // Version is the client's capabilities. + Version CapabilityVersion + + // NodeKey is the client's current node key. + NodeKey key.NodePublic + + // KeyID is the key we are querying using. + KeyID tkatype.KeyID +} + +// TKASignaturesUsingKeyResponse is the JSON response to +// a /tka/affected-sigs RPC. +// +// It enumerates all signatures which are signed by the +// queried keyID. +type TKASignaturesUsingKeyResponse struct { + Signatures []tkatype.MarshaledSignature +} diff --git a/taildrop/delete.go b/taildrop/delete.go index 7279a7687..aaef34df1 100644 --- a/taildrop/delete.go +++ b/taildrop/delete.go @@ -1,205 +1,205 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package taildrop - -import ( - "container/list" - "context" - "io/fs" - "os" - "path/filepath" - "strings" - "sync" - "time" - - "tailscale.com/ipn" - "tailscale.com/syncs" - "tailscale.com/tstime" - "tailscale.com/types/logger" -) - -// deleteDelay is the amount of time to wait before we delete a file. -// A shorter value ensures timely deletion of deleted and partial files, while -// a longer value provides more opportunity for partial files to be resumed. -const deleteDelay = time.Hour - -// fileDeleter manages asynchronous deletion of files after deleteDelay. -type fileDeleter struct { - logf logger.Logf - clock tstime.DefaultClock - dir string - event func(string) // called for certain events; for testing only - - mu sync.Mutex - queue list.List - byName map[string]*list.Element - - emptySignal chan struct{} // signal that the queue is empty - group syncs.WaitGroup - shutdownCtx context.Context - shutdown context.CancelFunc -} - -// deleteFile is a specific file to delete after deleteDelay. -type deleteFile struct { - name string - inserted time.Time -} - -func (d *fileDeleter) Init(m *Manager, eventHook func(string)) { - d.logf = m.opts.Logf - d.clock = m.opts.Clock - d.dir = m.opts.Dir - d.event = eventHook - - d.byName = make(map[string]*list.Element) - d.emptySignal = make(chan struct{}) - d.shutdownCtx, d.shutdown = context.WithCancel(context.Background()) - - // From a cold-start, load the list of partial and deleted files. - // - // Only run this if we have ever received at least one file - // to avoid ever touching the taildrop directory on systems (e.g., MacOS) - // that pop up a security dialog window upon first access. - if m.opts.State == nil { - return - } - if b, _ := m.opts.State.ReadState(ipn.TaildropReceivedKey); len(b) == 0 { - return - } - d.group.Go(func() { - d.event("start full-scan") - defer d.event("end full-scan") - rangeDir(d.dir, func(de fs.DirEntry) bool { - switch { - case d.shutdownCtx.Err() != nil: - return false // terminate early - case !de.Type().IsRegular(): - return true - case strings.HasSuffix(de.Name(), partialSuffix): - // Only enqueue the file for deletion if there is no active put. - nameID := strings.TrimSuffix(de.Name(), partialSuffix) - if i := strings.LastIndexByte(nameID, '.'); i > 0 { - key := incomingFileKey{ClientID(nameID[i+len("."):]), nameID[:i]} - m.incomingFiles.LoadFunc(key, func(_ *incomingFile, loaded bool) { - if !loaded { - d.Insert(de.Name()) - } - }) - } else { - d.Insert(de.Name()) - } - case strings.HasSuffix(de.Name(), deletedSuffix): - // Best-effort immediate deletion of deleted files. - name := strings.TrimSuffix(de.Name(), deletedSuffix) - if os.Remove(filepath.Join(d.dir, name)) == nil { - if os.Remove(filepath.Join(d.dir, de.Name())) == nil { - break - } - } - // Otherwise, enqueue the file for later deletion. - d.Insert(de.Name()) - } - return true - }) - }) -} - -// Insert enqueues baseName for eventual deletion. -func (d *fileDeleter) Insert(baseName string) { - d.mu.Lock() - defer d.mu.Unlock() - if d.shutdownCtx.Err() != nil { - return - } - if _, ok := d.byName[baseName]; ok { - return // already queued for deletion - } - d.byName[baseName] = d.queue.PushBack(&deleteFile{baseName, d.clock.Now()}) - if d.queue.Len() == 1 && d.shutdownCtx.Err() == nil { - d.group.Go(func() { d.waitAndDelete(deleteDelay) }) - } -} - -// waitAndDelete is an asynchronous deletion goroutine. -// At most one waitAndDelete routine is ever running at a time. -// It is not started unless there is at least one file in the queue. -func (d *fileDeleter) waitAndDelete(wait time.Duration) { - tc, ch := d.clock.NewTimer(wait) - defer tc.Stop() // cleanup the timer resource if we stop early - d.event("start waitAndDelete") - defer d.event("end waitAndDelete") - select { - case <-d.shutdownCtx.Done(): - case <-d.emptySignal: - case now := <-ch: - d.mu.Lock() - defer d.mu.Unlock() - - // Iterate over all files to delete, and delete anything old enough. - var next *list.Element - var failed []*list.Element - for elem := d.queue.Front(); elem != nil; elem = next { - next = elem.Next() - file := elem.Value.(*deleteFile) - if now.Sub(file.inserted) < deleteDelay { - break // everything after this is recently inserted - } - - // Delete the expired file. - if name, ok := strings.CutSuffix(file.name, deletedSuffix); ok { - if err := os.Remove(filepath.Join(d.dir, name)); err != nil && !os.IsNotExist(err) { - d.logf("could not delete: %v", redactError(err)) - failed = append(failed, elem) - continue - } - } - if err := os.Remove(filepath.Join(d.dir, file.name)); err != nil && !os.IsNotExist(err) { - d.logf("could not delete: %v", redactError(err)) - failed = append(failed, elem) - continue - } - d.queue.Remove(elem) - delete(d.byName, file.name) - d.event("deleted " + file.name) - } - for _, elem := range failed { - elem.Value.(*deleteFile).inserted = now // retry after deleteDelay - d.queue.MoveToBack(elem) - } - - // If there are still some files to delete, retry again later. - if d.queue.Len() > 0 && d.shutdownCtx.Err() == nil { - file := d.queue.Front().Value.(*deleteFile) - retryAfter := deleteDelay - now.Sub(file.inserted) - d.group.Go(func() { d.waitAndDelete(retryAfter) }) - } - } -} - -// Remove dequeues baseName from eventual deletion. -func (d *fileDeleter) Remove(baseName string) { - d.mu.Lock() - defer d.mu.Unlock() - if elem := d.byName[baseName]; elem != nil { - d.queue.Remove(elem) - delete(d.byName, baseName) - // Signal to terminate any waitAndDelete goroutines. - if d.queue.Len() == 0 { - select { - case <-d.shutdownCtx.Done(): - case d.emptySignal <- struct{}{}: - } - } - } -} - -// Shutdown shuts down the deleter. -// It blocks until all goroutines are stopped. -func (d *fileDeleter) Shutdown() { - d.mu.Lock() // acquire lock to ensure no new goroutines start after shutdown - d.shutdown() - d.mu.Unlock() - d.group.Wait() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop + +import ( + "container/list" + "context" + "io/fs" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "tailscale.com/ipn" + "tailscale.com/syncs" + "tailscale.com/tstime" + "tailscale.com/types/logger" +) + +// deleteDelay is the amount of time to wait before we delete a file. +// A shorter value ensures timely deletion of deleted and partial files, while +// a longer value provides more opportunity for partial files to be resumed. +const deleteDelay = time.Hour + +// fileDeleter manages asynchronous deletion of files after deleteDelay. +type fileDeleter struct { + logf logger.Logf + clock tstime.DefaultClock + dir string + event func(string) // called for certain events; for testing only + + mu sync.Mutex + queue list.List + byName map[string]*list.Element + + emptySignal chan struct{} // signal that the queue is empty + group syncs.WaitGroup + shutdownCtx context.Context + shutdown context.CancelFunc +} + +// deleteFile is a specific file to delete after deleteDelay. +type deleteFile struct { + name string + inserted time.Time +} + +func (d *fileDeleter) Init(m *Manager, eventHook func(string)) { + d.logf = m.opts.Logf + d.clock = m.opts.Clock + d.dir = m.opts.Dir + d.event = eventHook + + d.byName = make(map[string]*list.Element) + d.emptySignal = make(chan struct{}) + d.shutdownCtx, d.shutdown = context.WithCancel(context.Background()) + + // From a cold-start, load the list of partial and deleted files. + // + // Only run this if we have ever received at least one file + // to avoid ever touching the taildrop directory on systems (e.g., MacOS) + // that pop up a security dialog window upon first access. + if m.opts.State == nil { + return + } + if b, _ := m.opts.State.ReadState(ipn.TaildropReceivedKey); len(b) == 0 { + return + } + d.group.Go(func() { + d.event("start full-scan") + defer d.event("end full-scan") + rangeDir(d.dir, func(de fs.DirEntry) bool { + switch { + case d.shutdownCtx.Err() != nil: + return false // terminate early + case !de.Type().IsRegular(): + return true + case strings.HasSuffix(de.Name(), partialSuffix): + // Only enqueue the file for deletion if there is no active put. + nameID := strings.TrimSuffix(de.Name(), partialSuffix) + if i := strings.LastIndexByte(nameID, '.'); i > 0 { + key := incomingFileKey{ClientID(nameID[i+len("."):]), nameID[:i]} + m.incomingFiles.LoadFunc(key, func(_ *incomingFile, loaded bool) { + if !loaded { + d.Insert(de.Name()) + } + }) + } else { + d.Insert(de.Name()) + } + case strings.HasSuffix(de.Name(), deletedSuffix): + // Best-effort immediate deletion of deleted files. + name := strings.TrimSuffix(de.Name(), deletedSuffix) + if os.Remove(filepath.Join(d.dir, name)) == nil { + if os.Remove(filepath.Join(d.dir, de.Name())) == nil { + break + } + } + // Otherwise, enqueue the file for later deletion. + d.Insert(de.Name()) + } + return true + }) + }) +} + +// Insert enqueues baseName for eventual deletion. +func (d *fileDeleter) Insert(baseName string) { + d.mu.Lock() + defer d.mu.Unlock() + if d.shutdownCtx.Err() != nil { + return + } + if _, ok := d.byName[baseName]; ok { + return // already queued for deletion + } + d.byName[baseName] = d.queue.PushBack(&deleteFile{baseName, d.clock.Now()}) + if d.queue.Len() == 1 && d.shutdownCtx.Err() == nil { + d.group.Go(func() { d.waitAndDelete(deleteDelay) }) + } +} + +// waitAndDelete is an asynchronous deletion goroutine. +// At most one waitAndDelete routine is ever running at a time. +// It is not started unless there is at least one file in the queue. +func (d *fileDeleter) waitAndDelete(wait time.Duration) { + tc, ch := d.clock.NewTimer(wait) + defer tc.Stop() // cleanup the timer resource if we stop early + d.event("start waitAndDelete") + defer d.event("end waitAndDelete") + select { + case <-d.shutdownCtx.Done(): + case <-d.emptySignal: + case now := <-ch: + d.mu.Lock() + defer d.mu.Unlock() + + // Iterate over all files to delete, and delete anything old enough. + var next *list.Element + var failed []*list.Element + for elem := d.queue.Front(); elem != nil; elem = next { + next = elem.Next() + file := elem.Value.(*deleteFile) + if now.Sub(file.inserted) < deleteDelay { + break // everything after this is recently inserted + } + + // Delete the expired file. + if name, ok := strings.CutSuffix(file.name, deletedSuffix); ok { + if err := os.Remove(filepath.Join(d.dir, name)); err != nil && !os.IsNotExist(err) { + d.logf("could not delete: %v", redactError(err)) + failed = append(failed, elem) + continue + } + } + if err := os.Remove(filepath.Join(d.dir, file.name)); err != nil && !os.IsNotExist(err) { + d.logf("could not delete: %v", redactError(err)) + failed = append(failed, elem) + continue + } + d.queue.Remove(elem) + delete(d.byName, file.name) + d.event("deleted " + file.name) + } + for _, elem := range failed { + elem.Value.(*deleteFile).inserted = now // retry after deleteDelay + d.queue.MoveToBack(elem) + } + + // If there are still some files to delete, retry again later. + if d.queue.Len() > 0 && d.shutdownCtx.Err() == nil { + file := d.queue.Front().Value.(*deleteFile) + retryAfter := deleteDelay - now.Sub(file.inserted) + d.group.Go(func() { d.waitAndDelete(retryAfter) }) + } + } +} + +// Remove dequeues baseName from eventual deletion. +func (d *fileDeleter) Remove(baseName string) { + d.mu.Lock() + defer d.mu.Unlock() + if elem := d.byName[baseName]; elem != nil { + d.queue.Remove(elem) + delete(d.byName, baseName) + // Signal to terminate any waitAndDelete goroutines. + if d.queue.Len() == 0 { + select { + case <-d.shutdownCtx.Done(): + case d.emptySignal <- struct{}{}: + } + } + } +} + +// Shutdown shuts down the deleter. +// It blocks until all goroutines are stopped. +func (d *fileDeleter) Shutdown() { + d.mu.Lock() // acquire lock to ensure no new goroutines start after shutdown + d.shutdown() + d.mu.Unlock() + d.group.Wait() +} diff --git a/taildrop/delete_test.go b/taildrop/delete_test.go index b40fa35bf..5fa4b9c37 100644 --- a/taildrop/delete_test.go +++ b/taildrop/delete_test.go @@ -1,152 +1,152 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package taildrop - -import ( - "os" - "path/filepath" - "slices" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "tailscale.com/ipn" - "tailscale.com/ipn/store/mem" - "tailscale.com/tstest" - "tailscale.com/tstime" - "tailscale.com/util/must" -) - -func TestDeleter(t *testing.T) { - dir := t.TempDir() - must.Do(touchFile(filepath.Join(dir, "foo.partial"))) - must.Do(touchFile(filepath.Join(dir, "bar.partial"))) - must.Do(touchFile(filepath.Join(dir, "fizz"))) - must.Do(touchFile(filepath.Join(dir, "fizz.deleted"))) - must.Do(touchFile(filepath.Join(dir, "buzz.deleted"))) // lacks a matching "buzz" file - - checkDirectory := func(want ...string) { - t.Helper() - var got []string - for _, de := range must.Get(os.ReadDir(dir)) { - got = append(got, de.Name()) - } - slices.Sort(got) - slices.Sort(want) - if diff := cmp.Diff(got, want); diff != "" { - t.Fatalf("directory mismatch (-got +want):\n%s", diff) - } - } - - clock := tstest.NewClock(tstest.ClockOpts{Start: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)}) - advance := func(d time.Duration) { - t.Helper() - t.Logf("advance: %v", d) - clock.Advance(d) - } - - eventsChan := make(chan string, 1000) - checkEvents := func(want ...string) { - t.Helper() - tm := time.NewTimer(10 * time.Second) - defer tm.Stop() - var got []string - for range want { - select { - case event := <-eventsChan: - t.Logf("event: %s", event) - got = append(got, event) - case <-tm.C: - t.Fatalf("timed out waiting for event: got %v, want %v", got, want) - } - } - slices.Sort(got) - slices.Sort(want) - if diff := cmp.Diff(got, want); diff != "" { - t.Fatalf("events mismatch (-got +want):\n%s", diff) - } - } - eventHook := func(event string) { eventsChan <- event } - - var m Manager - var fd fileDeleter - m.opts.Logf = t.Logf - m.opts.Clock = tstime.DefaultClock{Clock: clock} - m.opts.Dir = dir - m.opts.State = must.Get(mem.New(nil, "")) - must.Do(m.opts.State.WriteState(ipn.TaildropReceivedKey, []byte{1})) - fd.Init(&m, eventHook) - defer fd.Shutdown() - insert := func(name string) { - t.Helper() - t.Logf("insert: %v", name) - fd.Insert(name) - } - remove := func(name string) { - t.Helper() - t.Logf("remove: %v", name) - fd.Remove(name) - } - - checkEvents("start full-scan") - checkEvents("end full-scan", "start waitAndDelete") - checkDirectory("foo.partial", "bar.partial", "buzz.deleted") - - advance(deleteDelay / 2) - checkDirectory("foo.partial", "bar.partial", "buzz.deleted") - advance(deleteDelay / 2) - checkEvents("deleted foo.partial", "deleted bar.partial", "deleted buzz.deleted") - checkEvents("end waitAndDelete") - checkDirectory() - - must.Do(touchFile(filepath.Join(dir, "one.partial"))) - insert("one.partial") - checkEvents("start waitAndDelete") - advance(deleteDelay / 4) - must.Do(touchFile(filepath.Join(dir, "two.partial"))) - insert("two.partial") - advance(deleteDelay / 4) - must.Do(touchFile(filepath.Join(dir, "three.partial"))) - insert("three.partial") - advance(deleteDelay / 4) - must.Do(touchFile(filepath.Join(dir, "four.partial"))) - insert("four.partial") - - advance(deleteDelay / 4) - checkEvents("deleted one.partial") - checkDirectory("two.partial", "three.partial", "four.partial") - checkEvents("end waitAndDelete", "start waitAndDelete") - - advance(deleteDelay / 4) - checkEvents("deleted two.partial") - checkDirectory("three.partial", "four.partial") - checkEvents("end waitAndDelete", "start waitAndDelete") - - advance(deleteDelay / 4) - checkEvents("deleted three.partial") - checkDirectory("four.partial") - checkEvents("end waitAndDelete", "start waitAndDelete") - - advance(deleteDelay / 4) - checkEvents("deleted four.partial") - checkDirectory() - checkEvents("end waitAndDelete") - - insert("wuzz.partial") - checkEvents("start waitAndDelete") - remove("wuzz.partial") - checkEvents("end waitAndDelete") -} - -// Test that the asynchronous full scan of the taildrop directory does not occur -// on a cold start if taildrop has never received any files. -func TestDeleterInitWithoutTaildrop(t *testing.T) { - var m Manager - var fd fileDeleter - m.opts.Logf = t.Logf - m.opts.Dir = t.TempDir() - m.opts.State = must.Get(mem.New(nil, "")) - fd.Init(&m, func(event string) { t.Errorf("unexpected event: %v", event) }) - fd.Shutdown() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop + +import ( + "os" + "path/filepath" + "slices" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "tailscale.com/ipn" + "tailscale.com/ipn/store/mem" + "tailscale.com/tstest" + "tailscale.com/tstime" + "tailscale.com/util/must" +) + +func TestDeleter(t *testing.T) { + dir := t.TempDir() + must.Do(touchFile(filepath.Join(dir, "foo.partial"))) + must.Do(touchFile(filepath.Join(dir, "bar.partial"))) + must.Do(touchFile(filepath.Join(dir, "fizz"))) + must.Do(touchFile(filepath.Join(dir, "fizz.deleted"))) + must.Do(touchFile(filepath.Join(dir, "buzz.deleted"))) // lacks a matching "buzz" file + + checkDirectory := func(want ...string) { + t.Helper() + var got []string + for _, de := range must.Get(os.ReadDir(dir)) { + got = append(got, de.Name()) + } + slices.Sort(got) + slices.Sort(want) + if diff := cmp.Diff(got, want); diff != "" { + t.Fatalf("directory mismatch (-got +want):\n%s", diff) + } + } + + clock := tstest.NewClock(tstest.ClockOpts{Start: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)}) + advance := func(d time.Duration) { + t.Helper() + t.Logf("advance: %v", d) + clock.Advance(d) + } + + eventsChan := make(chan string, 1000) + checkEvents := func(want ...string) { + t.Helper() + tm := time.NewTimer(10 * time.Second) + defer tm.Stop() + var got []string + for range want { + select { + case event := <-eventsChan: + t.Logf("event: %s", event) + got = append(got, event) + case <-tm.C: + t.Fatalf("timed out waiting for event: got %v, want %v", got, want) + } + } + slices.Sort(got) + slices.Sort(want) + if diff := cmp.Diff(got, want); diff != "" { + t.Fatalf("events mismatch (-got +want):\n%s", diff) + } + } + eventHook := func(event string) { eventsChan <- event } + + var m Manager + var fd fileDeleter + m.opts.Logf = t.Logf + m.opts.Clock = tstime.DefaultClock{Clock: clock} + m.opts.Dir = dir + m.opts.State = must.Get(mem.New(nil, "")) + must.Do(m.opts.State.WriteState(ipn.TaildropReceivedKey, []byte{1})) + fd.Init(&m, eventHook) + defer fd.Shutdown() + insert := func(name string) { + t.Helper() + t.Logf("insert: %v", name) + fd.Insert(name) + } + remove := func(name string) { + t.Helper() + t.Logf("remove: %v", name) + fd.Remove(name) + } + + checkEvents("start full-scan") + checkEvents("end full-scan", "start waitAndDelete") + checkDirectory("foo.partial", "bar.partial", "buzz.deleted") + + advance(deleteDelay / 2) + checkDirectory("foo.partial", "bar.partial", "buzz.deleted") + advance(deleteDelay / 2) + checkEvents("deleted foo.partial", "deleted bar.partial", "deleted buzz.deleted") + checkEvents("end waitAndDelete") + checkDirectory() + + must.Do(touchFile(filepath.Join(dir, "one.partial"))) + insert("one.partial") + checkEvents("start waitAndDelete") + advance(deleteDelay / 4) + must.Do(touchFile(filepath.Join(dir, "two.partial"))) + insert("two.partial") + advance(deleteDelay / 4) + must.Do(touchFile(filepath.Join(dir, "three.partial"))) + insert("three.partial") + advance(deleteDelay / 4) + must.Do(touchFile(filepath.Join(dir, "four.partial"))) + insert("four.partial") + + advance(deleteDelay / 4) + checkEvents("deleted one.partial") + checkDirectory("two.partial", "three.partial", "four.partial") + checkEvents("end waitAndDelete", "start waitAndDelete") + + advance(deleteDelay / 4) + checkEvents("deleted two.partial") + checkDirectory("three.partial", "four.partial") + checkEvents("end waitAndDelete", "start waitAndDelete") + + advance(deleteDelay / 4) + checkEvents("deleted three.partial") + checkDirectory("four.partial") + checkEvents("end waitAndDelete", "start waitAndDelete") + + advance(deleteDelay / 4) + checkEvents("deleted four.partial") + checkDirectory() + checkEvents("end waitAndDelete") + + insert("wuzz.partial") + checkEvents("start waitAndDelete") + remove("wuzz.partial") + checkEvents("end waitAndDelete") +} + +// Test that the asynchronous full scan of the taildrop directory does not occur +// on a cold start if taildrop has never received any files. +func TestDeleterInitWithoutTaildrop(t *testing.T) { + var m Manager + var fd fileDeleter + m.opts.Logf = t.Logf + m.opts.Dir = t.TempDir() + m.opts.State = must.Get(mem.New(nil, "")) + fd.Init(&m, func(event string) { t.Errorf("unexpected event: %v", event) }) + fd.Shutdown() +} diff --git a/taildrop/resume_test.go b/taildrop/resume_test.go index 8758ddd29..d366340eb 100644 --- a/taildrop/resume_test.go +++ b/taildrop/resume_test.go @@ -1,74 +1,74 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package taildrop - -import ( - "bytes" - "io" - "math/rand" - "os" - "testing" - "testing/iotest" - - "tailscale.com/util/must" -) - -func TestResume(t *testing.T) { - oldBlockSize := blockSize - defer func() { blockSize = oldBlockSize }() - blockSize = 256 - - m := ManagerOptions{Logf: t.Logf, Dir: t.TempDir()}.New() - defer m.Shutdown() - - rn := rand.New(rand.NewSource(0)) - want := make([]byte, 12345) - must.Get(io.ReadFull(rn, want)) - - t.Run("resume-noexist", func(t *testing.T) { - r := io.Reader(bytes.NewReader(want)) - - next, close, err := m.HashPartialFile("", "foo") - must.Do(err) - defer close() - offset, r, err := ResumeReader(r, next) - must.Do(err) - must.Do(close()) // Windows wants the file handle to be closed to rename it. - - must.Get(m.PutFile("", "foo", r, offset, -1)) - got := must.Get(os.ReadFile(must.Get(joinDir(m.opts.Dir, "foo")))) - if !bytes.Equal(got, want) { - t.Errorf("content mismatches") - } - }) - - t.Run("resume-retry", func(t *testing.T) { - rn := rand.New(rand.NewSource(0)) - for i := 0; true; i++ { - r := io.Reader(bytes.NewReader(want)) - - next, close, err := m.HashPartialFile("", "bar") - must.Do(err) - defer close() - offset, r, err := ResumeReader(r, next) - must.Do(err) - must.Do(close()) // Windows wants the file handle to be closed to rename it. - - numWant := rn.Int63n(min(int64(len(want))-offset, 1000) + 1) - if offset < int64(len(want)) { - r = io.MultiReader(io.LimitReader(r, numWant), iotest.ErrReader(io.ErrClosedPipe)) - } - if _, err := m.PutFile("", "bar", r, offset, -1); err == nil { - break - } - if i > 1000 { - t.Fatalf("too many iterations to complete the test") - } - } - got := must.Get(os.ReadFile(must.Get(joinDir(m.opts.Dir, "bar")))) - if !bytes.Equal(got, want) { - t.Errorf("content mismatches") - } - }) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop + +import ( + "bytes" + "io" + "math/rand" + "os" + "testing" + "testing/iotest" + + "tailscale.com/util/must" +) + +func TestResume(t *testing.T) { + oldBlockSize := blockSize + defer func() { blockSize = oldBlockSize }() + blockSize = 256 + + m := ManagerOptions{Logf: t.Logf, Dir: t.TempDir()}.New() + defer m.Shutdown() + + rn := rand.New(rand.NewSource(0)) + want := make([]byte, 12345) + must.Get(io.ReadFull(rn, want)) + + t.Run("resume-noexist", func(t *testing.T) { + r := io.Reader(bytes.NewReader(want)) + + next, close, err := m.HashPartialFile("", "foo") + must.Do(err) + defer close() + offset, r, err := ResumeReader(r, next) + must.Do(err) + must.Do(close()) // Windows wants the file handle to be closed to rename it. + + must.Get(m.PutFile("", "foo", r, offset, -1)) + got := must.Get(os.ReadFile(must.Get(joinDir(m.opts.Dir, "foo")))) + if !bytes.Equal(got, want) { + t.Errorf("content mismatches") + } + }) + + t.Run("resume-retry", func(t *testing.T) { + rn := rand.New(rand.NewSource(0)) + for i := 0; true; i++ { + r := io.Reader(bytes.NewReader(want)) + + next, close, err := m.HashPartialFile("", "bar") + must.Do(err) + defer close() + offset, r, err := ResumeReader(r, next) + must.Do(err) + must.Do(close()) // Windows wants the file handle to be closed to rename it. + + numWant := rn.Int63n(min(int64(len(want))-offset, 1000) + 1) + if offset < int64(len(want)) { + r = io.MultiReader(io.LimitReader(r, numWant), iotest.ErrReader(io.ErrClosedPipe)) + } + if _, err := m.PutFile("", "bar", r, offset, -1); err == nil { + break + } + if i > 1000 { + t.Fatalf("too many iterations to complete the test") + } + } + got := must.Get(os.ReadFile(must.Get(joinDir(m.opts.Dir, "bar")))) + if !bytes.Equal(got, want) { + t.Errorf("content mismatches") + } + }) +} diff --git a/taildrop/retrieve.go b/taildrop/retrieve.go index 527f8caed..3e37b492a 100644 --- a/taildrop/retrieve.go +++ b/taildrop/retrieve.go @@ -1,178 +1,178 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package taildrop - -import ( - "context" - "errors" - "io" - "io/fs" - "os" - "path/filepath" - "runtime" - "sort" - "time" - - "tailscale.com/client/tailscale/apitype" - "tailscale.com/logtail/backoff" -) - -// HasFilesWaiting reports whether any files are buffered in [Handler.Dir]. -// This always returns false when [Handler.DirectFileMode] is false. -func (m *Manager) HasFilesWaiting() (has bool) { - if m == nil || m.opts.Dir == "" || m.opts.DirectFileMode { - return false - } - - // Optimization: this is usually empty, so avoid opening - // the directory and checking. We can't cache the actual - // has-files-or-not values as the macOS/iOS client might - // in the future use+delete the files directly. So only - // keep this negative cache. - totalReceived := m.totalReceived.Load() - if totalReceived == m.emptySince.Load() { - return false - } - - // Check whether there is at least one one waiting file. - err := rangeDir(m.opts.Dir, func(de fs.DirEntry) bool { - name := de.Name() - if isPartialOrDeleted(name) || !de.Type().IsRegular() { - return true - } - _, err := os.Stat(filepath.Join(m.opts.Dir, name+deletedSuffix)) - if os.IsNotExist(err) { - has = true - return false - } - return true - }) - - // If there are no more waiting files, record totalReceived as emptySince - // so that we can short-circuit the expensive directory traversal - // if no files have been received after the start of this call. - if err == nil && !has { - m.emptySince.Store(totalReceived) - } - return has -} - -// WaitingFiles returns the list of files that have been sent by a -// peer that are waiting in [Handler.Dir]. -// This always returns nil when [Handler.DirectFileMode] is false. -func (m *Manager) WaitingFiles() (ret []apitype.WaitingFile, err error) { - if m == nil || m.opts.Dir == "" { - return nil, ErrNoTaildrop - } - if m.opts.DirectFileMode { - return nil, nil - } - if err := rangeDir(m.opts.Dir, func(de fs.DirEntry) bool { - name := de.Name() - if isPartialOrDeleted(name) || !de.Type().IsRegular() { - return true - } - _, err := os.Stat(filepath.Join(m.opts.Dir, name+deletedSuffix)) - if os.IsNotExist(err) { - fi, err := de.Info() - if err != nil { - return true - } - ret = append(ret, apitype.WaitingFile{ - Name: filepath.Base(name), - Size: fi.Size(), - }) - } - return true - }); err != nil { - return nil, redactError(err) - } - sort.Slice(ret, func(i, j int) bool { return ret[i].Name < ret[j].Name }) - return ret, nil -} - -// DeleteFile deletes a file of the given baseName from [Handler.Dir]. -// This method is only allowed when [Handler.DirectFileMode] is false. -func (m *Manager) DeleteFile(baseName string) error { - if m == nil || m.opts.Dir == "" { - return ErrNoTaildrop - } - if m.opts.DirectFileMode { - return errors.New("deletes not allowed in direct mode") - } - path, err := joinDir(m.opts.Dir, baseName) - if err != nil { - return err - } - var bo *backoff.Backoff - logf := m.opts.Logf - t0 := m.opts.Clock.Now() - for { - err := os.Remove(path) - if err != nil && !os.IsNotExist(err) { - err = redactError(err) - // Put a retry loop around deletes on Windows. - // - // Windows file descriptor closes are effectively asynchronous, - // as a bunch of hooks run on/after close, - // and we can't necessarily delete the file for a while after close, - // as we need to wait for everybody to be done with it. - // On Windows, unlike Unix, a file can't be deleted if it's open anywhere. - // So try a few times but ultimately just leave a "foo.jpg.deleted" - // marker file to note that it's deleted and we clean it up later. - if runtime.GOOS == "windows" { - if bo == nil { - bo = backoff.NewBackoff("delete-retry", logf, 1*time.Second) - } - if m.opts.Clock.Since(t0) < 5*time.Second { - bo.BackOff(context.Background(), err) - continue - } - if err := touchFile(path + deletedSuffix); err != nil { - logf("peerapi: failed to leave deleted marker: %v", err) - } - m.deleter.Insert(baseName + deletedSuffix) - } - logf("peerapi: failed to DeleteFile: %v", err) - return err - } - return nil - } -} - -func touchFile(path string) error { - f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0666) - if err != nil { - return redactError(err) - } - return f.Close() -} - -// OpenFile opens a file of the given baseName from [Handler.Dir]. -// This method is only allowed when [Handler.DirectFileMode] is false. -func (m *Manager) OpenFile(baseName string) (rc io.ReadCloser, size int64, err error) { - if m == nil || m.opts.Dir == "" { - return nil, 0, ErrNoTaildrop - } - if m.opts.DirectFileMode { - return nil, 0, errors.New("opens not allowed in direct mode") - } - path, err := joinDir(m.opts.Dir, baseName) - if err != nil { - return nil, 0, err - } - if _, err := os.Stat(path + deletedSuffix); err == nil { - return nil, 0, redactError(&fs.PathError{Op: "open", Path: path, Err: fs.ErrNotExist}) - } - f, err := os.Open(path) - if err != nil { - return nil, 0, redactError(err) - } - fi, err := f.Stat() - if err != nil { - f.Close() - return nil, 0, redactError(err) - } - return f, fi.Size(), nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop + +import ( + "context" + "errors" + "io" + "io/fs" + "os" + "path/filepath" + "runtime" + "sort" + "time" + + "tailscale.com/client/tailscale/apitype" + "tailscale.com/logtail/backoff" +) + +// HasFilesWaiting reports whether any files are buffered in [Handler.Dir]. +// This always returns false when [Handler.DirectFileMode] is false. +func (m *Manager) HasFilesWaiting() (has bool) { + if m == nil || m.opts.Dir == "" || m.opts.DirectFileMode { + return false + } + + // Optimization: this is usually empty, so avoid opening + // the directory and checking. We can't cache the actual + // has-files-or-not values as the macOS/iOS client might + // in the future use+delete the files directly. So only + // keep this negative cache. + totalReceived := m.totalReceived.Load() + if totalReceived == m.emptySince.Load() { + return false + } + + // Check whether there is at least one one waiting file. + err := rangeDir(m.opts.Dir, func(de fs.DirEntry) bool { + name := de.Name() + if isPartialOrDeleted(name) || !de.Type().IsRegular() { + return true + } + _, err := os.Stat(filepath.Join(m.opts.Dir, name+deletedSuffix)) + if os.IsNotExist(err) { + has = true + return false + } + return true + }) + + // If there are no more waiting files, record totalReceived as emptySince + // so that we can short-circuit the expensive directory traversal + // if no files have been received after the start of this call. + if err == nil && !has { + m.emptySince.Store(totalReceived) + } + return has +} + +// WaitingFiles returns the list of files that have been sent by a +// peer that are waiting in [Handler.Dir]. +// This always returns nil when [Handler.DirectFileMode] is false. +func (m *Manager) WaitingFiles() (ret []apitype.WaitingFile, err error) { + if m == nil || m.opts.Dir == "" { + return nil, ErrNoTaildrop + } + if m.opts.DirectFileMode { + return nil, nil + } + if err := rangeDir(m.opts.Dir, func(de fs.DirEntry) bool { + name := de.Name() + if isPartialOrDeleted(name) || !de.Type().IsRegular() { + return true + } + _, err := os.Stat(filepath.Join(m.opts.Dir, name+deletedSuffix)) + if os.IsNotExist(err) { + fi, err := de.Info() + if err != nil { + return true + } + ret = append(ret, apitype.WaitingFile{ + Name: filepath.Base(name), + Size: fi.Size(), + }) + } + return true + }); err != nil { + return nil, redactError(err) + } + sort.Slice(ret, func(i, j int) bool { return ret[i].Name < ret[j].Name }) + return ret, nil +} + +// DeleteFile deletes a file of the given baseName from [Handler.Dir]. +// This method is only allowed when [Handler.DirectFileMode] is false. +func (m *Manager) DeleteFile(baseName string) error { + if m == nil || m.opts.Dir == "" { + return ErrNoTaildrop + } + if m.opts.DirectFileMode { + return errors.New("deletes not allowed in direct mode") + } + path, err := joinDir(m.opts.Dir, baseName) + if err != nil { + return err + } + var bo *backoff.Backoff + logf := m.opts.Logf + t0 := m.opts.Clock.Now() + for { + err := os.Remove(path) + if err != nil && !os.IsNotExist(err) { + err = redactError(err) + // Put a retry loop around deletes on Windows. + // + // Windows file descriptor closes are effectively asynchronous, + // as a bunch of hooks run on/after close, + // and we can't necessarily delete the file for a while after close, + // as we need to wait for everybody to be done with it. + // On Windows, unlike Unix, a file can't be deleted if it's open anywhere. + // So try a few times but ultimately just leave a "foo.jpg.deleted" + // marker file to note that it's deleted and we clean it up later. + if runtime.GOOS == "windows" { + if bo == nil { + bo = backoff.NewBackoff("delete-retry", logf, 1*time.Second) + } + if m.opts.Clock.Since(t0) < 5*time.Second { + bo.BackOff(context.Background(), err) + continue + } + if err := touchFile(path + deletedSuffix); err != nil { + logf("peerapi: failed to leave deleted marker: %v", err) + } + m.deleter.Insert(baseName + deletedSuffix) + } + logf("peerapi: failed to DeleteFile: %v", err) + return err + } + return nil + } +} + +func touchFile(path string) error { + f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0666) + if err != nil { + return redactError(err) + } + return f.Close() +} + +// OpenFile opens a file of the given baseName from [Handler.Dir]. +// This method is only allowed when [Handler.DirectFileMode] is false. +func (m *Manager) OpenFile(baseName string) (rc io.ReadCloser, size int64, err error) { + if m == nil || m.opts.Dir == "" { + return nil, 0, ErrNoTaildrop + } + if m.opts.DirectFileMode { + return nil, 0, errors.New("opens not allowed in direct mode") + } + path, err := joinDir(m.opts.Dir, baseName) + if err != nil { + return nil, 0, err + } + if _, err := os.Stat(path + deletedSuffix); err == nil { + return nil, 0, redactError(&fs.PathError{Op: "open", Path: path, Err: fs.ErrNotExist}) + } + f, err := os.Open(path) + if err != nil { + return nil, 0, redactError(err) + } + fi, err := f.Stat() + if err != nil { + f.Close() + return nil, 0, redactError(err) + } + return f, fi.Size(), nil +} diff --git a/tempfork/gliderlabs/ssh/LICENSE b/tempfork/gliderlabs/ssh/LICENSE index 80b2b2baa..4a03f02a2 100644 --- a/tempfork/gliderlabs/ssh/LICENSE +++ b/tempfork/gliderlabs/ssh/LICENSE @@ -1,27 +1,27 @@ -Copyright (c) 2016 Glider Labs. All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - * Neither the name of Glider Labs nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +Copyright (c) 2016 Glider Labs. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Glider Labs nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/tempfork/gliderlabs/ssh/README.md b/tempfork/gliderlabs/ssh/README.md index ecef6b7c4..79b5b89fa 100644 --- a/tempfork/gliderlabs/ssh/README.md +++ b/tempfork/gliderlabs/ssh/README.md @@ -1,96 +1,96 @@ -# gliderlabs/ssh - -[![GoDoc](https://godoc.org/tailscale.com/tempfork/gliderlabs/ssh?status.svg)](https://godoc.org/github.com/gliderlabs/ssh) -[![CircleCI](https://img.shields.io/circleci/project/github/gliderlabs/ssh.svg)](https://circleci.com/gh/gliderlabs/ssh) -[![Go Report Card](https://goreportcard.com/badge/tailscale.com/tempfork/gliderlabs/ssh)](https://goreportcard.com/report/github.com/gliderlabs/ssh) -[![OpenCollective](https://opencollective.com/ssh/sponsors/badge.svg)](#sponsors) -[![Slack](http://slack.gliderlabs.com/badge.svg)](http://slack.gliderlabs.com) -[![Email Updates](https://img.shields.io/badge/updates-subscribe-yellow.svg)](https://app.convertkit.com/landing_pages/243312) - -> The Glider Labs SSH server package is dope. —[@bradfitz](https://twitter.com/bradfitz), Go team member - -This Go package wraps the [crypto/ssh -package](https://godoc.org/golang.org/x/crypto/ssh) with a higher-level API for -building SSH servers. The goal of the API was to make it as simple as using -[net/http](https://golang.org/pkg/net/http/), so the API is very similar: - -```go - package main - - import ( - "tailscale.com/tempfork/gliderlabs/ssh" - "io" - "log" - ) - - func main() { - ssh.Handle(func(s ssh.Session) { - io.WriteString(s, "Hello world\n") - }) - - log.Fatal(ssh.ListenAndServe(":2222", nil)) - } - -``` -This package was built by [@progrium](https://twitter.com/progrium) after working on nearly a dozen projects at Glider Labs using SSH and collaborating with [@shazow](https://twitter.com/shazow) (known for [ssh-chat](https://github.com/shazow/ssh-chat)). - -## Examples - -A bunch of great examples are in the `_examples` directory. - -## Usage - -[See GoDoc reference.](https://godoc.org/tailscale.com/tempfork/gliderlabs/ssh) - -## Contributing - -Pull requests are welcome! However, since this project is very much about API -design, please submit API changes as issues to discuss before submitting PRs. - -Also, you can [join our Slack](http://slack.gliderlabs.com) to discuss as well. - -## Roadmap - -* Non-session channel handlers -* Cleanup callback API -* 1.0 release -* High-level client? - -## Sponsors - -Become a sponsor and get your logo on our README on Github with a link to your site. [[Become a sponsor](https://opencollective.com/ssh#sponsor)] - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -## License - -[BSD](LICENSE) +# gliderlabs/ssh + +[![GoDoc](https://godoc.org/tailscale.com/tempfork/gliderlabs/ssh?status.svg)](https://godoc.org/github.com/gliderlabs/ssh) +[![CircleCI](https://img.shields.io/circleci/project/github/gliderlabs/ssh.svg)](https://circleci.com/gh/gliderlabs/ssh) +[![Go Report Card](https://goreportcard.com/badge/tailscale.com/tempfork/gliderlabs/ssh)](https://goreportcard.com/report/github.com/gliderlabs/ssh) +[![OpenCollective](https://opencollective.com/ssh/sponsors/badge.svg)](#sponsors) +[![Slack](http://slack.gliderlabs.com/badge.svg)](http://slack.gliderlabs.com) +[![Email Updates](https://img.shields.io/badge/updates-subscribe-yellow.svg)](https://app.convertkit.com/landing_pages/243312) + +> The Glider Labs SSH server package is dope. —[@bradfitz](https://twitter.com/bradfitz), Go team member + +This Go package wraps the [crypto/ssh +package](https://godoc.org/golang.org/x/crypto/ssh) with a higher-level API for +building SSH servers. The goal of the API was to make it as simple as using +[net/http](https://golang.org/pkg/net/http/), so the API is very similar: + +```go + package main + + import ( + "tailscale.com/tempfork/gliderlabs/ssh" + "io" + "log" + ) + + func main() { + ssh.Handle(func(s ssh.Session) { + io.WriteString(s, "Hello world\n") + }) + + log.Fatal(ssh.ListenAndServe(":2222", nil)) + } + +``` +This package was built by [@progrium](https://twitter.com/progrium) after working on nearly a dozen projects at Glider Labs using SSH and collaborating with [@shazow](https://twitter.com/shazow) (known for [ssh-chat](https://github.com/shazow/ssh-chat)). + +## Examples + +A bunch of great examples are in the `_examples` directory. + +## Usage + +[See GoDoc reference.](https://godoc.org/tailscale.com/tempfork/gliderlabs/ssh) + +## Contributing + +Pull requests are welcome! However, since this project is very much about API +design, please submit API changes as issues to discuss before submitting PRs. + +Also, you can [join our Slack](http://slack.gliderlabs.com) to discuss as well. + +## Roadmap + +* Non-session channel handlers +* Cleanup callback API +* 1.0 release +* High-level client? + +## Sponsors + +Become a sponsor and get your logo on our README on Github with a link to your site. [[Become a sponsor](https://opencollective.com/ssh#sponsor)] + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +## License + +[BSD](LICENSE) diff --git a/tempfork/gliderlabs/ssh/agent.go b/tempfork/gliderlabs/ssh/agent.go index 3da665292..86a5bce7f 100644 --- a/tempfork/gliderlabs/ssh/agent.go +++ b/tempfork/gliderlabs/ssh/agent.go @@ -1,83 +1,83 @@ -package ssh - -import ( - "io" - "net" - "os" - "path" - "sync" - - gossh "github.com/tailscale/golang-x-crypto/ssh" -) - -const ( - agentRequestType = "auth-agent-req@openssh.com" - agentChannelType = "auth-agent@openssh.com" - - agentTempDir = "auth-agent" - agentListenFile = "listener.sock" -) - -// contextKeyAgentRequest is an internal context key for storing if the -// client requested agent forwarding -var contextKeyAgentRequest = &contextKey{"auth-agent-req"} - -// SetAgentRequested sets up the session context so that AgentRequested -// returns true. -func SetAgentRequested(ctx Context) { - ctx.SetValue(contextKeyAgentRequest, true) -} - -// AgentRequested returns true if the client requested agent forwarding. -func AgentRequested(sess Session) bool { - return sess.Context().Value(contextKeyAgentRequest) == true -} - -// NewAgentListener sets up a temporary Unix socket that can be communicated -// to the session environment and used for forwarding connections. -func NewAgentListener() (net.Listener, error) { - dir, err := os.MkdirTemp("", agentTempDir) - if err != nil { - return nil, err - } - l, err := net.Listen("unix", path.Join(dir, agentListenFile)) - if err != nil { - return nil, err - } - return l, nil -} - -// ForwardAgentConnections takes connections from a listener to proxy into the -// session on the OpenSSH channel for agent connections. It blocks and services -// connections until the listener stop accepting. -func ForwardAgentConnections(l net.Listener, s Session) { - sshConn := s.Context().Value(ContextKeyConn).(gossh.Conn) - for { - conn, err := l.Accept() - if err != nil { - return - } - go func(conn net.Conn) { - defer conn.Close() - channel, reqs, err := sshConn.OpenChannel(agentChannelType, nil) - if err != nil { - return - } - defer channel.Close() - go gossh.DiscardRequests(reqs) - var wg sync.WaitGroup - wg.Add(2) - go func() { - io.Copy(conn, channel) - conn.(*net.UnixConn).CloseWrite() - wg.Done() - }() - go func() { - io.Copy(channel, conn) - channel.CloseWrite() - wg.Done() - }() - wg.Wait() - }(conn) - } -} +package ssh + +import ( + "io" + "net" + "os" + "path" + "sync" + + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +const ( + agentRequestType = "auth-agent-req@openssh.com" + agentChannelType = "auth-agent@openssh.com" + + agentTempDir = "auth-agent" + agentListenFile = "listener.sock" +) + +// contextKeyAgentRequest is an internal context key for storing if the +// client requested agent forwarding +var contextKeyAgentRequest = &contextKey{"auth-agent-req"} + +// SetAgentRequested sets up the session context so that AgentRequested +// returns true. +func SetAgentRequested(ctx Context) { + ctx.SetValue(contextKeyAgentRequest, true) +} + +// AgentRequested returns true if the client requested agent forwarding. +func AgentRequested(sess Session) bool { + return sess.Context().Value(contextKeyAgentRequest) == true +} + +// NewAgentListener sets up a temporary Unix socket that can be communicated +// to the session environment and used for forwarding connections. +func NewAgentListener() (net.Listener, error) { + dir, err := os.MkdirTemp("", agentTempDir) + if err != nil { + return nil, err + } + l, err := net.Listen("unix", path.Join(dir, agentListenFile)) + if err != nil { + return nil, err + } + return l, nil +} + +// ForwardAgentConnections takes connections from a listener to proxy into the +// session on the OpenSSH channel for agent connections. It blocks and services +// connections until the listener stop accepting. +func ForwardAgentConnections(l net.Listener, s Session) { + sshConn := s.Context().Value(ContextKeyConn).(gossh.Conn) + for { + conn, err := l.Accept() + if err != nil { + return + } + go func(conn net.Conn) { + defer conn.Close() + channel, reqs, err := sshConn.OpenChannel(agentChannelType, nil) + if err != nil { + return + } + defer channel.Close() + go gossh.DiscardRequests(reqs) + var wg sync.WaitGroup + wg.Add(2) + go func() { + io.Copy(conn, channel) + conn.(*net.UnixConn).CloseWrite() + wg.Done() + }() + go func() { + io.Copy(channel, conn) + channel.CloseWrite() + wg.Done() + }() + wg.Wait() + }(conn) + } +} diff --git a/tempfork/gliderlabs/ssh/conn.go b/tempfork/gliderlabs/ssh/conn.go index ec277bf27..ebef8845b 100644 --- a/tempfork/gliderlabs/ssh/conn.go +++ b/tempfork/gliderlabs/ssh/conn.go @@ -1,55 +1,55 @@ -package ssh - -import ( - "context" - "net" - "time" -) - -type serverConn struct { - net.Conn - - idleTimeout time.Duration - maxDeadline time.Time - closeCanceler context.CancelFunc -} - -func (c *serverConn) Write(p []byte) (n int, err error) { - c.updateDeadline() - n, err = c.Conn.Write(p) - if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil { - c.closeCanceler() - } - return -} - -func (c *serverConn) Read(b []byte) (n int, err error) { - c.updateDeadline() - n, err = c.Conn.Read(b) - if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil { - c.closeCanceler() - } - return -} - -func (c *serverConn) Close() (err error) { - err = c.Conn.Close() - if c.closeCanceler != nil { - c.closeCanceler() - } - return -} - -func (c *serverConn) updateDeadline() { - switch { - case c.idleTimeout > 0: - idleDeadline := time.Now().Add(c.idleTimeout) - if idleDeadline.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() { - c.Conn.SetDeadline(idleDeadline) - return - } - fallthrough - default: - c.Conn.SetDeadline(c.maxDeadline) - } -} +package ssh + +import ( + "context" + "net" + "time" +) + +type serverConn struct { + net.Conn + + idleTimeout time.Duration + maxDeadline time.Time + closeCanceler context.CancelFunc +} + +func (c *serverConn) Write(p []byte) (n int, err error) { + c.updateDeadline() + n, err = c.Conn.Write(p) + if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil { + c.closeCanceler() + } + return +} + +func (c *serverConn) Read(b []byte) (n int, err error) { + c.updateDeadline() + n, err = c.Conn.Read(b) + if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil { + c.closeCanceler() + } + return +} + +func (c *serverConn) Close() (err error) { + err = c.Conn.Close() + if c.closeCanceler != nil { + c.closeCanceler() + } + return +} + +func (c *serverConn) updateDeadline() { + switch { + case c.idleTimeout > 0: + idleDeadline := time.Now().Add(c.idleTimeout) + if idleDeadline.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() { + c.Conn.SetDeadline(idleDeadline) + return + } + fallthrough + default: + c.Conn.SetDeadline(c.maxDeadline) + } +} diff --git a/tempfork/gliderlabs/ssh/context.go b/tempfork/gliderlabs/ssh/context.go index 6f7245574..d43de6f09 100644 --- a/tempfork/gliderlabs/ssh/context.go +++ b/tempfork/gliderlabs/ssh/context.go @@ -1,164 +1,164 @@ -package ssh - -import ( - "context" - "encoding/hex" - "net" - "sync" - - gossh "github.com/tailscale/golang-x-crypto/ssh" -) - -// contextKey is a value for use with context.WithValue. It's used as -// a pointer so it fits in an interface{} without allocation. -type contextKey struct { - name string -} - -var ( - // ContextKeyUser is a context key for use with Contexts in this package. - // The associated value will be of type string. - ContextKeyUser = &contextKey{"user"} - - // ContextKeySessionID is a context key for use with Contexts in this package. - // The associated value will be of type string. - ContextKeySessionID = &contextKey{"session-id"} - - // ContextKeyPermissions is a context key for use with Contexts in this package. - // The associated value will be of type *Permissions. - ContextKeyPermissions = &contextKey{"permissions"} - - // ContextKeyClientVersion is a context key for use with Contexts in this package. - // The associated value will be of type string. - ContextKeyClientVersion = &contextKey{"client-version"} - - // ContextKeyServerVersion is a context key for use with Contexts in this package. - // The associated value will be of type string. - ContextKeyServerVersion = &contextKey{"server-version"} - - // ContextKeyLocalAddr is a context key for use with Contexts in this package. - // The associated value will be of type net.Addr. - ContextKeyLocalAddr = &contextKey{"local-addr"} - - // ContextKeyRemoteAddr is a context key for use with Contexts in this package. - // The associated value will be of type net.Addr. - ContextKeyRemoteAddr = &contextKey{"remote-addr"} - - // ContextKeyServer is a context key for use with Contexts in this package. - // The associated value will be of type *Server. - ContextKeyServer = &contextKey{"ssh-server"} - - // ContextKeyConn is a context key for use with Contexts in this package. - // The associated value will be of type gossh.ServerConn. - ContextKeyConn = &contextKey{"ssh-conn"} - - // ContextKeyPublicKey is a context key for use with Contexts in this package. - // The associated value will be of type PublicKey. - ContextKeyPublicKey = &contextKey{"public-key"} - - ContextKeySendAuthBanner = &contextKey{"send-auth-banner"} -) - -// Context is a package specific context interface. It exposes connection -// metadata and allows new values to be easily written to it. It's used in -// authentication handlers and callbacks, and its underlying context.Context is -// exposed on Session in the session Handler. A connection-scoped lock is also -// embedded in the context to make it easier to limit operations per-connection. -type Context interface { - context.Context - sync.Locker - - // User returns the username used when establishing the SSH connection. - User() string - - // SessionID returns the session hash. - SessionID() string - - // ClientVersion returns the version reported by the client. - ClientVersion() string - - // ServerVersion returns the version reported by the server. - ServerVersion() string - - // RemoteAddr returns the remote address for this connection. - RemoteAddr() net.Addr - - // LocalAddr returns the local address for this connection. - LocalAddr() net.Addr - - // Permissions returns the Permissions object used for this connection. - Permissions() *Permissions - - // SetValue allows you to easily write new values into the underlying context. - SetValue(key, value interface{}) - - SendAuthBanner(banner string) error -} - -type sshContext struct { - context.Context - *sync.Mutex -} - -func newContext(srv *Server) (*sshContext, context.CancelFunc) { - innerCtx, cancel := context.WithCancel(context.Background()) - ctx := &sshContext{innerCtx, &sync.Mutex{}} - ctx.SetValue(ContextKeyServer, srv) - perms := &Permissions{&gossh.Permissions{}} - ctx.SetValue(ContextKeyPermissions, perms) - return ctx, cancel -} - -// this is separate from newContext because we will get ConnMetadata -// at different points so it needs to be applied separately -func applyConnMetadata(ctx Context, conn gossh.ConnMetadata) { - if ctx.Value(ContextKeySessionID) != nil { - return - } - ctx.SetValue(ContextKeySessionID, hex.EncodeToString(conn.SessionID())) - ctx.SetValue(ContextKeyClientVersion, string(conn.ClientVersion())) - ctx.SetValue(ContextKeyServerVersion, string(conn.ServerVersion())) - ctx.SetValue(ContextKeyUser, conn.User()) - ctx.SetValue(ContextKeyLocalAddr, conn.LocalAddr()) - ctx.SetValue(ContextKeyRemoteAddr, conn.RemoteAddr()) - ctx.SetValue(ContextKeySendAuthBanner, conn.SendAuthBanner) -} - -func (ctx *sshContext) SetValue(key, value interface{}) { - ctx.Context = context.WithValue(ctx.Context, key, value) -} - -func (ctx *sshContext) User() string { - return ctx.Value(ContextKeyUser).(string) -} - -func (ctx *sshContext) SessionID() string { - return ctx.Value(ContextKeySessionID).(string) -} - -func (ctx *sshContext) ClientVersion() string { - return ctx.Value(ContextKeyClientVersion).(string) -} - -func (ctx *sshContext) ServerVersion() string { - return ctx.Value(ContextKeyServerVersion).(string) -} - -func (ctx *sshContext) RemoteAddr() net.Addr { - if addr, ok := ctx.Value(ContextKeyRemoteAddr).(net.Addr); ok { - return addr - } - return nil -} - -func (ctx *sshContext) LocalAddr() net.Addr { - return ctx.Value(ContextKeyLocalAddr).(net.Addr) -} - -func (ctx *sshContext) Permissions() *Permissions { - return ctx.Value(ContextKeyPermissions).(*Permissions) -} - -func (ctx *sshContext) SendAuthBanner(msg string) error { - return ctx.Value(ContextKeySendAuthBanner).(func(string) error)(msg) -} +package ssh + +import ( + "context" + "encoding/hex" + "net" + "sync" + + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +// contextKey is a value for use with context.WithValue. It's used as +// a pointer so it fits in an interface{} without allocation. +type contextKey struct { + name string +} + +var ( + // ContextKeyUser is a context key for use with Contexts in this package. + // The associated value will be of type string. + ContextKeyUser = &contextKey{"user"} + + // ContextKeySessionID is a context key for use with Contexts in this package. + // The associated value will be of type string. + ContextKeySessionID = &contextKey{"session-id"} + + // ContextKeyPermissions is a context key for use with Contexts in this package. + // The associated value will be of type *Permissions. + ContextKeyPermissions = &contextKey{"permissions"} + + // ContextKeyClientVersion is a context key for use with Contexts in this package. + // The associated value will be of type string. + ContextKeyClientVersion = &contextKey{"client-version"} + + // ContextKeyServerVersion is a context key for use with Contexts in this package. + // The associated value will be of type string. + ContextKeyServerVersion = &contextKey{"server-version"} + + // ContextKeyLocalAddr is a context key for use with Contexts in this package. + // The associated value will be of type net.Addr. + ContextKeyLocalAddr = &contextKey{"local-addr"} + + // ContextKeyRemoteAddr is a context key for use with Contexts in this package. + // The associated value will be of type net.Addr. + ContextKeyRemoteAddr = &contextKey{"remote-addr"} + + // ContextKeyServer is a context key for use with Contexts in this package. + // The associated value will be of type *Server. + ContextKeyServer = &contextKey{"ssh-server"} + + // ContextKeyConn is a context key for use with Contexts in this package. + // The associated value will be of type gossh.ServerConn. + ContextKeyConn = &contextKey{"ssh-conn"} + + // ContextKeyPublicKey is a context key for use with Contexts in this package. + // The associated value will be of type PublicKey. + ContextKeyPublicKey = &contextKey{"public-key"} + + ContextKeySendAuthBanner = &contextKey{"send-auth-banner"} +) + +// Context is a package specific context interface. It exposes connection +// metadata and allows new values to be easily written to it. It's used in +// authentication handlers and callbacks, and its underlying context.Context is +// exposed on Session in the session Handler. A connection-scoped lock is also +// embedded in the context to make it easier to limit operations per-connection. +type Context interface { + context.Context + sync.Locker + + // User returns the username used when establishing the SSH connection. + User() string + + // SessionID returns the session hash. + SessionID() string + + // ClientVersion returns the version reported by the client. + ClientVersion() string + + // ServerVersion returns the version reported by the server. + ServerVersion() string + + // RemoteAddr returns the remote address for this connection. + RemoteAddr() net.Addr + + // LocalAddr returns the local address for this connection. + LocalAddr() net.Addr + + // Permissions returns the Permissions object used for this connection. + Permissions() *Permissions + + // SetValue allows you to easily write new values into the underlying context. + SetValue(key, value interface{}) + + SendAuthBanner(banner string) error +} + +type sshContext struct { + context.Context + *sync.Mutex +} + +func newContext(srv *Server) (*sshContext, context.CancelFunc) { + innerCtx, cancel := context.WithCancel(context.Background()) + ctx := &sshContext{innerCtx, &sync.Mutex{}} + ctx.SetValue(ContextKeyServer, srv) + perms := &Permissions{&gossh.Permissions{}} + ctx.SetValue(ContextKeyPermissions, perms) + return ctx, cancel +} + +// this is separate from newContext because we will get ConnMetadata +// at different points so it needs to be applied separately +func applyConnMetadata(ctx Context, conn gossh.ConnMetadata) { + if ctx.Value(ContextKeySessionID) != nil { + return + } + ctx.SetValue(ContextKeySessionID, hex.EncodeToString(conn.SessionID())) + ctx.SetValue(ContextKeyClientVersion, string(conn.ClientVersion())) + ctx.SetValue(ContextKeyServerVersion, string(conn.ServerVersion())) + ctx.SetValue(ContextKeyUser, conn.User()) + ctx.SetValue(ContextKeyLocalAddr, conn.LocalAddr()) + ctx.SetValue(ContextKeyRemoteAddr, conn.RemoteAddr()) + ctx.SetValue(ContextKeySendAuthBanner, conn.SendAuthBanner) +} + +func (ctx *sshContext) SetValue(key, value interface{}) { + ctx.Context = context.WithValue(ctx.Context, key, value) +} + +func (ctx *sshContext) User() string { + return ctx.Value(ContextKeyUser).(string) +} + +func (ctx *sshContext) SessionID() string { + return ctx.Value(ContextKeySessionID).(string) +} + +func (ctx *sshContext) ClientVersion() string { + return ctx.Value(ContextKeyClientVersion).(string) +} + +func (ctx *sshContext) ServerVersion() string { + return ctx.Value(ContextKeyServerVersion).(string) +} + +func (ctx *sshContext) RemoteAddr() net.Addr { + if addr, ok := ctx.Value(ContextKeyRemoteAddr).(net.Addr); ok { + return addr + } + return nil +} + +func (ctx *sshContext) LocalAddr() net.Addr { + return ctx.Value(ContextKeyLocalAddr).(net.Addr) +} + +func (ctx *sshContext) Permissions() *Permissions { + return ctx.Value(ContextKeyPermissions).(*Permissions) +} + +func (ctx *sshContext) SendAuthBanner(msg string) error { + return ctx.Value(ContextKeySendAuthBanner).(func(string) error)(msg) +} diff --git a/tempfork/gliderlabs/ssh/context_test.go b/tempfork/gliderlabs/ssh/context_test.go index 8f71c3958..dcbd326b7 100644 --- a/tempfork/gliderlabs/ssh/context_test.go +++ b/tempfork/gliderlabs/ssh/context_test.go @@ -1,49 +1,49 @@ -//go:build glidertests - -package ssh - -import "testing" - -func TestSetPermissions(t *testing.T) { - t.Parallel() - permsExt := map[string]string{ - "foo": "bar", - } - session, _, cleanup := newTestSessionWithOptions(t, &Server{ - Handler: func(s Session) { - if _, ok := s.Permissions().Extensions["foo"]; !ok { - t.Fatalf("got %#v; want %#v", s.Permissions().Extensions, permsExt) - } - }, - }, nil, PasswordAuth(func(ctx Context, password string) bool { - ctx.Permissions().Extensions = permsExt - return true - })) - defer cleanup() - if err := session.Run(""); err != nil { - t.Fatal(err) - } -} - -func TestSetValue(t *testing.T) { - t.Parallel() - value := map[string]string{ - "foo": "bar", - } - key := "testValue" - session, _, cleanup := newTestSessionWithOptions(t, &Server{ - Handler: func(s Session) { - v := s.Context().Value(key).(map[string]string) - if v["foo"] != value["foo"] { - t.Fatalf("got %#v; want %#v", v, value) - } - }, - }, nil, PasswordAuth(func(ctx Context, password string) bool { - ctx.SetValue(key, value) - return true - })) - defer cleanup() - if err := session.Run(""); err != nil { - t.Fatal(err) - } -} +//go:build glidertests + +package ssh + +import "testing" + +func TestSetPermissions(t *testing.T) { + t.Parallel() + permsExt := map[string]string{ + "foo": "bar", + } + session, _, cleanup := newTestSessionWithOptions(t, &Server{ + Handler: func(s Session) { + if _, ok := s.Permissions().Extensions["foo"]; !ok { + t.Fatalf("got %#v; want %#v", s.Permissions().Extensions, permsExt) + } + }, + }, nil, PasswordAuth(func(ctx Context, password string) bool { + ctx.Permissions().Extensions = permsExt + return true + })) + defer cleanup() + if err := session.Run(""); err != nil { + t.Fatal(err) + } +} + +func TestSetValue(t *testing.T) { + t.Parallel() + value := map[string]string{ + "foo": "bar", + } + key := "testValue" + session, _, cleanup := newTestSessionWithOptions(t, &Server{ + Handler: func(s Session) { + v := s.Context().Value(key).(map[string]string) + if v["foo"] != value["foo"] { + t.Fatalf("got %#v; want %#v", v, value) + } + }, + }, nil, PasswordAuth(func(ctx Context, password string) bool { + ctx.SetValue(key, value) + return true + })) + defer cleanup() + if err := session.Run(""); err != nil { + t.Fatal(err) + } +} diff --git a/tempfork/gliderlabs/ssh/doc.go b/tempfork/gliderlabs/ssh/doc.go index 46c47d650..d13919176 100644 --- a/tempfork/gliderlabs/ssh/doc.go +++ b/tempfork/gliderlabs/ssh/doc.go @@ -1,45 +1,45 @@ -/* -Package ssh wraps the crypto/ssh package with a higher-level API for building -SSH servers. The goal of the API was to make it as simple as using net/http, so -the API is very similar. - -You should be able to build any SSH server using only this package, which wraps -relevant types and some functions from crypto/ssh. However, you still need to -use crypto/ssh for building SSH clients. - -ListenAndServe starts an SSH server with a given address, handler, and options. The -handler is usually nil, which means to use DefaultHandler. Handle sets DefaultHandler: - - ssh.Handle(func(s ssh.Session) { - io.WriteString(s, "Hello world\n") - }) - - log.Fatal(ssh.ListenAndServe(":2222", nil)) - -If you don't specify a host key, it will generate one every time. This is convenient -except you'll have to deal with clients being confused that the host key is different. -It's a better idea to generate or point to an existing key on your system: - - log.Fatal(ssh.ListenAndServe(":2222", nil, ssh.HostKeyFile("/Users/progrium/.ssh/id_rsa"))) - -Although all options have functional option helpers, another way to control the -server's behavior is by creating a custom Server: - - s := &ssh.Server{ - Addr: ":2222", - Handler: sessionHandler, - PublicKeyHandler: authHandler, - } - s.AddHostKey(hostKeySigner) - - log.Fatal(s.ListenAndServe()) - -This package automatically handles basic SSH requests like setting environment -variables, requesting PTY, and changing window size. These requests are -processed, responded to, and any relevant state is updated. This state is then -exposed to you via the Session interface. - -The one big feature missing from the Session abstraction is signals. This was -started, but not completed. Pull Requests welcome! -*/ -package ssh +/* +Package ssh wraps the crypto/ssh package with a higher-level API for building +SSH servers. The goal of the API was to make it as simple as using net/http, so +the API is very similar. + +You should be able to build any SSH server using only this package, which wraps +relevant types and some functions from crypto/ssh. However, you still need to +use crypto/ssh for building SSH clients. + +ListenAndServe starts an SSH server with a given address, handler, and options. The +handler is usually nil, which means to use DefaultHandler. Handle sets DefaultHandler: + + ssh.Handle(func(s ssh.Session) { + io.WriteString(s, "Hello world\n") + }) + + log.Fatal(ssh.ListenAndServe(":2222", nil)) + +If you don't specify a host key, it will generate one every time. This is convenient +except you'll have to deal with clients being confused that the host key is different. +It's a better idea to generate or point to an existing key on your system: + + log.Fatal(ssh.ListenAndServe(":2222", nil, ssh.HostKeyFile("/Users/progrium/.ssh/id_rsa"))) + +Although all options have functional option helpers, another way to control the +server's behavior is by creating a custom Server: + + s := &ssh.Server{ + Addr: ":2222", + Handler: sessionHandler, + PublicKeyHandler: authHandler, + } + s.AddHostKey(hostKeySigner) + + log.Fatal(s.ListenAndServe()) + +This package automatically handles basic SSH requests like setting environment +variables, requesting PTY, and changing window size. These requests are +processed, responded to, and any relevant state is updated. This state is then +exposed to you via the Session interface. + +The one big feature missing from the Session abstraction is signals. This was +started, but not completed. Pull Requests welcome! +*/ +package ssh diff --git a/tempfork/gliderlabs/ssh/example_test.go b/tempfork/gliderlabs/ssh/example_test.go index 61ffebbc0..c174bc4ae 100644 --- a/tempfork/gliderlabs/ssh/example_test.go +++ b/tempfork/gliderlabs/ssh/example_test.go @@ -1,50 +1,50 @@ -package ssh_test - -import ( - "errors" - "io" - "os" - - "tailscale.com/tempfork/gliderlabs/ssh" -) - -func ExampleListenAndServe() { - ssh.ListenAndServe(":2222", func(s ssh.Session) { - io.WriteString(s, "Hello world\n") - }) -} - -func ExamplePasswordAuth() { - ssh.ListenAndServe(":2222", nil, - ssh.PasswordAuth(func(ctx ssh.Context, pass string) bool { - return pass == "secret" - }), - ) -} - -func ExampleNoPty() { - ssh.ListenAndServe(":2222", nil, ssh.NoPty()) -} - -func ExamplePublicKeyAuth() { - ssh.ListenAndServe(":2222", nil, - ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) error { - data, err := os.ReadFile("/path/to/allowed/key.pub") - if err != nil { - return err - } - allowed, _, _, _, err := ssh.ParseAuthorizedKey(data) - if err != nil { - return err - } - if !ssh.KeysEqual(key, allowed) { - return errors.New("some error") - } - return nil - }), - ) -} - -func ExampleHostKeyFile() { - ssh.ListenAndServe(":2222", nil, ssh.HostKeyFile("/path/to/host/key")) -} +package ssh_test + +import ( + "errors" + "io" + "os" + + "tailscale.com/tempfork/gliderlabs/ssh" +) + +func ExampleListenAndServe() { + ssh.ListenAndServe(":2222", func(s ssh.Session) { + io.WriteString(s, "Hello world\n") + }) +} + +func ExamplePasswordAuth() { + ssh.ListenAndServe(":2222", nil, + ssh.PasswordAuth(func(ctx ssh.Context, pass string) bool { + return pass == "secret" + }), + ) +} + +func ExampleNoPty() { + ssh.ListenAndServe(":2222", nil, ssh.NoPty()) +} + +func ExamplePublicKeyAuth() { + ssh.ListenAndServe(":2222", nil, + ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) error { + data, err := os.ReadFile("/path/to/allowed/key.pub") + if err != nil { + return err + } + allowed, _, _, _, err := ssh.ParseAuthorizedKey(data) + if err != nil { + return err + } + if !ssh.KeysEqual(key, allowed) { + return errors.New("some error") + } + return nil + }), + ) +} + +func ExampleHostKeyFile() { + ssh.ListenAndServe(":2222", nil, ssh.HostKeyFile("/path/to/host/key")) +} diff --git a/tempfork/gliderlabs/ssh/options.go b/tempfork/gliderlabs/ssh/options.go index bb24909be..aa87a4f39 100644 --- a/tempfork/gliderlabs/ssh/options.go +++ b/tempfork/gliderlabs/ssh/options.go @@ -1,84 +1,84 @@ -package ssh - -import ( - "os" - - gossh "github.com/tailscale/golang-x-crypto/ssh" -) - -// PasswordAuth returns a functional option that sets PasswordHandler on the server. -func PasswordAuth(fn PasswordHandler) Option { - return func(srv *Server) error { - srv.PasswordHandler = fn - return nil - } -} - -// PublicKeyAuth returns a functional option that sets PublicKeyHandler on the server. -func PublicKeyAuth(fn PublicKeyHandler) Option { - return func(srv *Server) error { - srv.PublicKeyHandler = fn - return nil - } -} - -// HostKeyFile returns a functional option that adds HostSigners to the server -// from a PEM file at filepath. -func HostKeyFile(filepath string) Option { - return func(srv *Server) error { - pemBytes, err := os.ReadFile(filepath) - if err != nil { - return err - } - - signer, err := gossh.ParsePrivateKey(pemBytes) - if err != nil { - return err - } - - srv.AddHostKey(signer) - - return nil - } -} - -func KeyboardInteractiveAuth(fn KeyboardInteractiveHandler) Option { - return func(srv *Server) error { - srv.KeyboardInteractiveHandler = fn - return nil - } -} - -// HostKeyPEM returns a functional option that adds HostSigners to the server -// from a PEM file as bytes. -func HostKeyPEM(bytes []byte) Option { - return func(srv *Server) error { - signer, err := gossh.ParsePrivateKey(bytes) - if err != nil { - return err - } - - srv.AddHostKey(signer) - - return nil - } -} - -// NoPty returns a functional option that sets PtyCallback to return false, -// denying PTY requests. -func NoPty() Option { - return func(srv *Server) error { - srv.PtyCallback = func(ctx Context, pty Pty) bool { - return false - } - return nil - } -} - -// WrapConn returns a functional option that sets ConnCallback on the server. -func WrapConn(fn ConnCallback) Option { - return func(srv *Server) error { - srv.ConnCallback = fn - return nil - } -} +package ssh + +import ( + "os" + + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +// PasswordAuth returns a functional option that sets PasswordHandler on the server. +func PasswordAuth(fn PasswordHandler) Option { + return func(srv *Server) error { + srv.PasswordHandler = fn + return nil + } +} + +// PublicKeyAuth returns a functional option that sets PublicKeyHandler on the server. +func PublicKeyAuth(fn PublicKeyHandler) Option { + return func(srv *Server) error { + srv.PublicKeyHandler = fn + return nil + } +} + +// HostKeyFile returns a functional option that adds HostSigners to the server +// from a PEM file at filepath. +func HostKeyFile(filepath string) Option { + return func(srv *Server) error { + pemBytes, err := os.ReadFile(filepath) + if err != nil { + return err + } + + signer, err := gossh.ParsePrivateKey(pemBytes) + if err != nil { + return err + } + + srv.AddHostKey(signer) + + return nil + } +} + +func KeyboardInteractiveAuth(fn KeyboardInteractiveHandler) Option { + return func(srv *Server) error { + srv.KeyboardInteractiveHandler = fn + return nil + } +} + +// HostKeyPEM returns a functional option that adds HostSigners to the server +// from a PEM file as bytes. +func HostKeyPEM(bytes []byte) Option { + return func(srv *Server) error { + signer, err := gossh.ParsePrivateKey(bytes) + if err != nil { + return err + } + + srv.AddHostKey(signer) + + return nil + } +} + +// NoPty returns a functional option that sets PtyCallback to return false, +// denying PTY requests. +func NoPty() Option { + return func(srv *Server) error { + srv.PtyCallback = func(ctx Context, pty Pty) bool { + return false + } + return nil + } +} + +// WrapConn returns a functional option that sets ConnCallback on the server. +func WrapConn(fn ConnCallback) Option { + return func(srv *Server) error { + srv.ConnCallback = fn + return nil + } +} diff --git a/tempfork/gliderlabs/ssh/options_test.go b/tempfork/gliderlabs/ssh/options_test.go index 3aa2f1cf5..7cf6f376c 100644 --- a/tempfork/gliderlabs/ssh/options_test.go +++ b/tempfork/gliderlabs/ssh/options_test.go @@ -1,111 +1,111 @@ -//go:build glidertests - -package ssh - -import ( - "net" - "strings" - "sync/atomic" - "testing" - - gossh "github.com/tailscale/golang-x-crypto/ssh" -) - -func newTestSessionWithOptions(t *testing.T, srv *Server, cfg *gossh.ClientConfig, options ...Option) (*gossh.Session, *gossh.Client, func()) { - for _, option := range options { - if err := srv.SetOption(option); err != nil { - t.Fatal(err) - } - } - return newTestSession(t, srv, cfg) -} - -func TestPasswordAuth(t *testing.T) { - t.Parallel() - testUser := "testuser" - testPass := "testpass" - session, _, cleanup := newTestSessionWithOptions(t, &Server{ - Handler: func(s Session) { - // noop - }, - }, &gossh.ClientConfig{ - User: testUser, - Auth: []gossh.AuthMethod{ - gossh.Password(testPass), - }, - HostKeyCallback: gossh.InsecureIgnoreHostKey(), - }, PasswordAuth(func(ctx Context, password string) bool { - if ctx.User() != testUser { - t.Fatalf("user = %#v; want %#v", ctx.User(), testUser) - } - if password != testPass { - t.Fatalf("user = %#v; want %#v", password, testPass) - } - return true - })) - defer cleanup() - if err := session.Run(""); err != nil { - t.Fatal(err) - } -} - -func TestPasswordAuthBadPass(t *testing.T) { - t.Parallel() - l := newLocalListener() - srv := &Server{Handler: func(s Session) {}} - srv.SetOption(PasswordAuth(func(ctx Context, password string) bool { - return false - })) - go srv.serveOnce(l) - _, err := gossh.Dial("tcp", l.Addr().String(), &gossh.ClientConfig{ - User: "testuser", - Auth: []gossh.AuthMethod{ - gossh.Password("testpass"), - }, - HostKeyCallback: gossh.InsecureIgnoreHostKey(), - }) - if err != nil { - if !strings.Contains(err.Error(), "unable to authenticate") { - t.Fatal(err) - } - } -} - -type wrappedConn struct { - net.Conn - written int32 -} - -func (c *wrappedConn) Write(p []byte) (n int, err error) { - n, err = c.Conn.Write(p) - atomic.AddInt32(&(c.written), int32(n)) - return -} - -func TestConnWrapping(t *testing.T) { - t.Parallel() - var wrapped *wrappedConn - session, _, cleanup := newTestSessionWithOptions(t, &Server{ - Handler: func(s Session) { - // nothing - }, - }, &gossh.ClientConfig{ - User: "testuser", - Auth: []gossh.AuthMethod{ - gossh.Password("testpass"), - }, - HostKeyCallback: gossh.InsecureIgnoreHostKey(), - }, PasswordAuth(func(ctx Context, password string) bool { - return true - }), WrapConn(func(ctx Context, conn net.Conn) net.Conn { - wrapped = &wrappedConn{conn, 0} - return wrapped - })) - defer cleanup() - if err := session.Shell(); err != nil { - t.Fatal(err) - } - if atomic.LoadInt32(&(wrapped.written)) == 0 { - t.Fatal("wrapped conn not written to") - } -} +//go:build glidertests + +package ssh + +import ( + "net" + "strings" + "sync/atomic" + "testing" + + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +func newTestSessionWithOptions(t *testing.T, srv *Server, cfg *gossh.ClientConfig, options ...Option) (*gossh.Session, *gossh.Client, func()) { + for _, option := range options { + if err := srv.SetOption(option); err != nil { + t.Fatal(err) + } + } + return newTestSession(t, srv, cfg) +} + +func TestPasswordAuth(t *testing.T) { + t.Parallel() + testUser := "testuser" + testPass := "testpass" + session, _, cleanup := newTestSessionWithOptions(t, &Server{ + Handler: func(s Session) { + // noop + }, + }, &gossh.ClientConfig{ + User: testUser, + Auth: []gossh.AuthMethod{ + gossh.Password(testPass), + }, + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + }, PasswordAuth(func(ctx Context, password string) bool { + if ctx.User() != testUser { + t.Fatalf("user = %#v; want %#v", ctx.User(), testUser) + } + if password != testPass { + t.Fatalf("user = %#v; want %#v", password, testPass) + } + return true + })) + defer cleanup() + if err := session.Run(""); err != nil { + t.Fatal(err) + } +} + +func TestPasswordAuthBadPass(t *testing.T) { + t.Parallel() + l := newLocalListener() + srv := &Server{Handler: func(s Session) {}} + srv.SetOption(PasswordAuth(func(ctx Context, password string) bool { + return false + })) + go srv.serveOnce(l) + _, err := gossh.Dial("tcp", l.Addr().String(), &gossh.ClientConfig{ + User: "testuser", + Auth: []gossh.AuthMethod{ + gossh.Password("testpass"), + }, + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + }) + if err != nil { + if !strings.Contains(err.Error(), "unable to authenticate") { + t.Fatal(err) + } + } +} + +type wrappedConn struct { + net.Conn + written int32 +} + +func (c *wrappedConn) Write(p []byte) (n int, err error) { + n, err = c.Conn.Write(p) + atomic.AddInt32(&(c.written), int32(n)) + return +} + +func TestConnWrapping(t *testing.T) { + t.Parallel() + var wrapped *wrappedConn + session, _, cleanup := newTestSessionWithOptions(t, &Server{ + Handler: func(s Session) { + // nothing + }, + }, &gossh.ClientConfig{ + User: "testuser", + Auth: []gossh.AuthMethod{ + gossh.Password("testpass"), + }, + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + }, PasswordAuth(func(ctx Context, password string) bool { + return true + }), WrapConn(func(ctx Context, conn net.Conn) net.Conn { + wrapped = &wrappedConn{conn, 0} + return wrapped + })) + defer cleanup() + if err := session.Shell(); err != nil { + t.Fatal(err) + } + if atomic.LoadInt32(&(wrapped.written)) == 0 { + t.Fatal("wrapped conn not written to") + } +} diff --git a/tempfork/gliderlabs/ssh/server.go b/tempfork/gliderlabs/ssh/server.go index 32f633e87..1086a72ca 100644 --- a/tempfork/gliderlabs/ssh/server.go +++ b/tempfork/gliderlabs/ssh/server.go @@ -1,459 +1,459 @@ -package ssh - -import ( - "context" - "errors" - "fmt" - "net" - "sync" - "time" - - gossh "github.com/tailscale/golang-x-crypto/ssh" -) - -// ErrServerClosed is returned by the Server's Serve, ListenAndServe, -// and ListenAndServeTLS methods after a call to Shutdown or Close. -var ErrServerClosed = errors.New("ssh: Server closed") - -type SubsystemHandler func(s Session) - -var DefaultSubsystemHandlers = map[string]SubsystemHandler{} - -type RequestHandler func(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) - -var DefaultRequestHandlers = map[string]RequestHandler{} - -type ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) - -var DefaultChannelHandlers = map[string]ChannelHandler{ - "session": DefaultSessionHandler, -} - -// Server defines parameters for running an SSH server. The zero value for -// Server is a valid configuration. When both PasswordHandler and -// PublicKeyHandler are nil, no client authentication is performed. -type Server struct { - Addr string // TCP address to listen on, ":22" if empty - Handler Handler // handler to invoke, ssh.DefaultHandler if nil - HostSigners []Signer // private keys for the host key, must have at least one - Version string // server version to be sent before the initial handshake - - KeyboardInteractiveHandler KeyboardInteractiveHandler // keyboard-interactive authentication handler - PasswordHandler PasswordHandler // password authentication handler - PublicKeyHandler PublicKeyHandler // public key authentication handler - NoClientAuthHandler NoClientAuthHandler // no client authentication handler - PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil - ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling - LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil - ReversePortForwardingCallback ReversePortForwardingCallback // callback for allowing reverse port forwarding, denies all if nil - ServerConfigCallback ServerConfigCallback // callback for configuring detailed SSH options - SessionRequestCallback SessionRequestCallback // callback for allowing or denying SSH sessions - - ConnectionFailedCallback ConnectionFailedCallback // callback to report connection failures - - IdleTimeout time.Duration // connection timeout when no activity, none if empty - MaxTimeout time.Duration // absolute connection timeout, none if empty - - // ChannelHandlers allow overriding the built-in session handlers or provide - // extensions to the protocol, such as tcpip forwarding. By default only the - // "session" handler is enabled. - ChannelHandlers map[string]ChannelHandler - - // RequestHandlers allow overriding the server-level request handlers or - // provide extensions to the protocol, such as tcpip forwarding. By default - // no handlers are enabled. - RequestHandlers map[string]RequestHandler - - // SubsystemHandlers are handlers which are similar to the usual SSH command - // handlers, but handle named subsystems. - SubsystemHandlers map[string]SubsystemHandler - - listenerWg sync.WaitGroup - mu sync.RWMutex - listeners map[net.Listener]struct{} - conns map[*gossh.ServerConn]struct{} - connWg sync.WaitGroup - doneChan chan struct{} -} - -func (srv *Server) ensureHostSigner() error { - srv.mu.Lock() - defer srv.mu.Unlock() - - if len(srv.HostSigners) == 0 { - signer, err := generateSigner() - if err != nil { - return err - } - srv.HostSigners = append(srv.HostSigners, signer) - } - return nil -} - -func (srv *Server) ensureHandlers() { - srv.mu.Lock() - defer srv.mu.Unlock() - - if srv.RequestHandlers == nil { - srv.RequestHandlers = map[string]RequestHandler{} - for k, v := range DefaultRequestHandlers { - srv.RequestHandlers[k] = v - } - } - if srv.ChannelHandlers == nil { - srv.ChannelHandlers = map[string]ChannelHandler{} - for k, v := range DefaultChannelHandlers { - srv.ChannelHandlers[k] = v - } - } - if srv.SubsystemHandlers == nil { - srv.SubsystemHandlers = map[string]SubsystemHandler{} - for k, v := range DefaultSubsystemHandlers { - srv.SubsystemHandlers[k] = v - } - } -} - -func (srv *Server) config(ctx Context) *gossh.ServerConfig { - srv.mu.RLock() - defer srv.mu.RUnlock() - - var config *gossh.ServerConfig - if srv.ServerConfigCallback == nil { - config = &gossh.ServerConfig{} - } else { - config = srv.ServerConfigCallback(ctx) - } - for _, signer := range srv.HostSigners { - config.AddHostKey(signer) - } - if srv.PasswordHandler == nil && srv.PublicKeyHandler == nil && srv.KeyboardInteractiveHandler == nil { - config.NoClientAuth = true - } - if srv.Version != "" { - config.ServerVersion = "SSH-2.0-" + srv.Version - } - if srv.PasswordHandler != nil { - config.PasswordCallback = func(conn gossh.ConnMetadata, password []byte) (*gossh.Permissions, error) { - applyConnMetadata(ctx, conn) - if ok := srv.PasswordHandler(ctx, string(password)); !ok { - return ctx.Permissions().Permissions, fmt.Errorf("permission denied") - } - return ctx.Permissions().Permissions, nil - } - } - if srv.PublicKeyHandler != nil { - config.PublicKeyCallback = func(conn gossh.ConnMetadata, key gossh.PublicKey) (*gossh.Permissions, error) { - applyConnMetadata(ctx, conn) - if err := srv.PublicKeyHandler(ctx, key); err != nil { - return ctx.Permissions().Permissions, err - } - ctx.SetValue(ContextKeyPublicKey, key) - return ctx.Permissions().Permissions, nil - } - } - if srv.KeyboardInteractiveHandler != nil { - config.KeyboardInteractiveCallback = func(conn gossh.ConnMetadata, challenger gossh.KeyboardInteractiveChallenge) (*gossh.Permissions, error) { - applyConnMetadata(ctx, conn) - if ok := srv.KeyboardInteractiveHandler(ctx, challenger); !ok { - return ctx.Permissions().Permissions, fmt.Errorf("permission denied") - } - return ctx.Permissions().Permissions, nil - } - } - if srv.NoClientAuthHandler != nil { - config.NoClientAuthCallback = func(conn gossh.ConnMetadata) (*gossh.Permissions, error) { - applyConnMetadata(ctx, conn) - if err := srv.NoClientAuthHandler(ctx); err != nil { - return ctx.Permissions().Permissions, err - } - return ctx.Permissions().Permissions, nil - } - } - return config -} - -// Handle sets the Handler for the server. -func (srv *Server) Handle(fn Handler) { - srv.mu.Lock() - defer srv.mu.Unlock() - - srv.Handler = fn -} - -// Close immediately closes all active listeners and all active -// connections. -// -// Close returns any error returned from closing the Server's -// underlying Listener(s). -func (srv *Server) Close() error { - srv.mu.Lock() - defer srv.mu.Unlock() - - srv.closeDoneChanLocked() - err := srv.closeListenersLocked() - for c := range srv.conns { - c.Close() - delete(srv.conns, c) - } - return err -} - -// Shutdown gracefully shuts down the server without interrupting any -// active connections. Shutdown works by first closing all open -// listeners, and then waiting indefinitely for connections to close. -// If the provided context expires before the shutdown is complete, -// then the context's error is returned. -func (srv *Server) Shutdown(ctx context.Context) error { - srv.mu.Lock() - lnerr := srv.closeListenersLocked() - srv.closeDoneChanLocked() - srv.mu.Unlock() - - finished := make(chan struct{}, 1) - go func() { - srv.listenerWg.Wait() - srv.connWg.Wait() - finished <- struct{}{} - }() - - select { - case <-ctx.Done(): - return ctx.Err() - case <-finished: - return lnerr - } -} - -// Serve accepts incoming connections on the Listener l, creating a new -// connection goroutine for each. The connection goroutines read requests and then -// calls srv.Handler to handle sessions. -// -// Serve always returns a non-nil error. -func (srv *Server) Serve(l net.Listener) error { - srv.ensureHandlers() - defer l.Close() - if err := srv.ensureHostSigner(); err != nil { - return err - } - if srv.Handler == nil { - srv.Handler = DefaultHandler - } - var tempDelay time.Duration - - srv.trackListener(l, true) - defer srv.trackListener(l, false) - for { - conn, e := l.Accept() - if e != nil { - select { - case <-srv.getDoneChan(): - return ErrServerClosed - default: - } - if ne, ok := e.(net.Error); ok && ne.Temporary() { - if tempDelay == 0 { - tempDelay = 5 * time.Millisecond - } else { - tempDelay *= 2 - } - if max := 1 * time.Second; tempDelay > max { - tempDelay = max - } - time.Sleep(tempDelay) - continue - } - return e - } - go srv.HandleConn(conn) - } -} - -func (srv *Server) HandleConn(newConn net.Conn) { - ctx, cancel := newContext(srv) - if srv.ConnCallback != nil { - cbConn := srv.ConnCallback(ctx, newConn) - if cbConn == nil { - newConn.Close() - return - } - newConn = cbConn - } - conn := &serverConn{ - Conn: newConn, - idleTimeout: srv.IdleTimeout, - closeCanceler: cancel, - } - if srv.MaxTimeout > 0 { - conn.maxDeadline = time.Now().Add(srv.MaxTimeout) - } - defer conn.Close() - sshConn, chans, reqs, err := gossh.NewServerConn(conn, srv.config(ctx)) - if err != nil { - if srv.ConnectionFailedCallback != nil { - srv.ConnectionFailedCallback(conn, err) - } - return - } - - srv.trackConn(sshConn, true) - defer srv.trackConn(sshConn, false) - - ctx.SetValue(ContextKeyConn, sshConn) - applyConnMetadata(ctx, sshConn) - //go gossh.DiscardRequests(reqs) - go srv.handleRequests(ctx, reqs) - for ch := range chans { - handler := srv.ChannelHandlers[ch.ChannelType()] - if handler == nil { - handler = srv.ChannelHandlers["default"] - } - if handler == nil { - ch.Reject(gossh.UnknownChannelType, "unsupported channel type") - continue - } - go handler(srv, sshConn, ch, ctx) - } -} - -func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) { - for req := range in { - handler := srv.RequestHandlers[req.Type] - if handler == nil { - handler = srv.RequestHandlers["default"] - } - if handler == nil { - req.Reply(false, nil) - continue - } - /*reqCtx, cancel := context.WithCancel(ctx) - defer cancel() */ - ret, payload := handler(ctx, srv, req) - req.Reply(ret, payload) - } -} - -// ListenAndServe listens on the TCP network address srv.Addr and then calls -// Serve to handle incoming connections. If srv.Addr is blank, ":22" is used. -// ListenAndServe always returns a non-nil error. -func (srv *Server) ListenAndServe() error { - addr := srv.Addr - if addr == "" { - addr = ":22" - } - ln, err := net.Listen("tcp", addr) - if err != nil { - return err - } - return srv.Serve(ln) -} - -// AddHostKey adds a private key as a host key. If an existing host key exists -// with the same algorithm, it is overwritten. Each server config must have at -// least one host key. -func (srv *Server) AddHostKey(key Signer) { - srv.mu.Lock() - defer srv.mu.Unlock() - - // these are later added via AddHostKey on ServerConfig, which performs the - // check for one of every algorithm. - - // This check is based on the AddHostKey method from the x/crypto/ssh - // library. This allows us to only keep one active key for each type on a - // server at once. So, if you're dynamically updating keys at runtime, this - // list will not keep growing. - for i, k := range srv.HostSigners { - if k.PublicKey().Type() == key.PublicKey().Type() { - srv.HostSigners[i] = key - return - } - } - - srv.HostSigners = append(srv.HostSigners, key) -} - -// SetOption runs a functional option against the server. -func (srv *Server) SetOption(option Option) error { - // NOTE: there is a potential race here for any option that doesn't call an - // internal method. We can't actually lock here because if something calls - // (as an example) AddHostKey, it will deadlock. - - //srv.mu.Lock() - //defer srv.mu.Unlock() - - return option(srv) -} - -func (srv *Server) getDoneChan() <-chan struct{} { - srv.mu.Lock() - defer srv.mu.Unlock() - - return srv.getDoneChanLocked() -} - -func (srv *Server) getDoneChanLocked() chan struct{} { - if srv.doneChan == nil { - srv.doneChan = make(chan struct{}) - } - return srv.doneChan -} - -func (srv *Server) closeDoneChanLocked() { - ch := srv.getDoneChanLocked() - select { - case <-ch: - // Already closed. Don't close again. - default: - // Safe to close here. We're the only closer, guarded - // by srv.mu. - close(ch) - } -} - -func (srv *Server) closeListenersLocked() error { - var err error - for ln := range srv.listeners { - if cerr := ln.Close(); cerr != nil && err == nil { - err = cerr - } - delete(srv.listeners, ln) - } - return err -} - -func (srv *Server) trackListener(ln net.Listener, add bool) { - srv.mu.Lock() - defer srv.mu.Unlock() - - if srv.listeners == nil { - srv.listeners = make(map[net.Listener]struct{}) - } - if add { - // If the *Server is being reused after a previous - // Close or Shutdown, reset its doneChan: - if len(srv.listeners) == 0 && len(srv.conns) == 0 { - srv.doneChan = nil - } - srv.listeners[ln] = struct{}{} - srv.listenerWg.Add(1) - } else { - delete(srv.listeners, ln) - srv.listenerWg.Done() - } -} - -func (srv *Server) trackConn(c *gossh.ServerConn, add bool) { - srv.mu.Lock() - defer srv.mu.Unlock() - - if srv.conns == nil { - srv.conns = make(map[*gossh.ServerConn]struct{}) - } - if add { - srv.conns[c] = struct{}{} - srv.connWg.Add(1) - } else { - delete(srv.conns, c) - srv.connWg.Done() - } -} +package ssh + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + "time" + + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +// ErrServerClosed is returned by the Server's Serve, ListenAndServe, +// and ListenAndServeTLS methods after a call to Shutdown or Close. +var ErrServerClosed = errors.New("ssh: Server closed") + +type SubsystemHandler func(s Session) + +var DefaultSubsystemHandlers = map[string]SubsystemHandler{} + +type RequestHandler func(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) + +var DefaultRequestHandlers = map[string]RequestHandler{} + +type ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) + +var DefaultChannelHandlers = map[string]ChannelHandler{ + "session": DefaultSessionHandler, +} + +// Server defines parameters for running an SSH server. The zero value for +// Server is a valid configuration. When both PasswordHandler and +// PublicKeyHandler are nil, no client authentication is performed. +type Server struct { + Addr string // TCP address to listen on, ":22" if empty + Handler Handler // handler to invoke, ssh.DefaultHandler if nil + HostSigners []Signer // private keys for the host key, must have at least one + Version string // server version to be sent before the initial handshake + + KeyboardInteractiveHandler KeyboardInteractiveHandler // keyboard-interactive authentication handler + PasswordHandler PasswordHandler // password authentication handler + PublicKeyHandler PublicKeyHandler // public key authentication handler + NoClientAuthHandler NoClientAuthHandler // no client authentication handler + PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil + ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling + LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil + ReversePortForwardingCallback ReversePortForwardingCallback // callback for allowing reverse port forwarding, denies all if nil + ServerConfigCallback ServerConfigCallback // callback for configuring detailed SSH options + SessionRequestCallback SessionRequestCallback // callback for allowing or denying SSH sessions + + ConnectionFailedCallback ConnectionFailedCallback // callback to report connection failures + + IdleTimeout time.Duration // connection timeout when no activity, none if empty + MaxTimeout time.Duration // absolute connection timeout, none if empty + + // ChannelHandlers allow overriding the built-in session handlers or provide + // extensions to the protocol, such as tcpip forwarding. By default only the + // "session" handler is enabled. + ChannelHandlers map[string]ChannelHandler + + // RequestHandlers allow overriding the server-level request handlers or + // provide extensions to the protocol, such as tcpip forwarding. By default + // no handlers are enabled. + RequestHandlers map[string]RequestHandler + + // SubsystemHandlers are handlers which are similar to the usual SSH command + // handlers, but handle named subsystems. + SubsystemHandlers map[string]SubsystemHandler + + listenerWg sync.WaitGroup + mu sync.RWMutex + listeners map[net.Listener]struct{} + conns map[*gossh.ServerConn]struct{} + connWg sync.WaitGroup + doneChan chan struct{} +} + +func (srv *Server) ensureHostSigner() error { + srv.mu.Lock() + defer srv.mu.Unlock() + + if len(srv.HostSigners) == 0 { + signer, err := generateSigner() + if err != nil { + return err + } + srv.HostSigners = append(srv.HostSigners, signer) + } + return nil +} + +func (srv *Server) ensureHandlers() { + srv.mu.Lock() + defer srv.mu.Unlock() + + if srv.RequestHandlers == nil { + srv.RequestHandlers = map[string]RequestHandler{} + for k, v := range DefaultRequestHandlers { + srv.RequestHandlers[k] = v + } + } + if srv.ChannelHandlers == nil { + srv.ChannelHandlers = map[string]ChannelHandler{} + for k, v := range DefaultChannelHandlers { + srv.ChannelHandlers[k] = v + } + } + if srv.SubsystemHandlers == nil { + srv.SubsystemHandlers = map[string]SubsystemHandler{} + for k, v := range DefaultSubsystemHandlers { + srv.SubsystemHandlers[k] = v + } + } +} + +func (srv *Server) config(ctx Context) *gossh.ServerConfig { + srv.mu.RLock() + defer srv.mu.RUnlock() + + var config *gossh.ServerConfig + if srv.ServerConfigCallback == nil { + config = &gossh.ServerConfig{} + } else { + config = srv.ServerConfigCallback(ctx) + } + for _, signer := range srv.HostSigners { + config.AddHostKey(signer) + } + if srv.PasswordHandler == nil && srv.PublicKeyHandler == nil && srv.KeyboardInteractiveHandler == nil { + config.NoClientAuth = true + } + if srv.Version != "" { + config.ServerVersion = "SSH-2.0-" + srv.Version + } + if srv.PasswordHandler != nil { + config.PasswordCallback = func(conn gossh.ConnMetadata, password []byte) (*gossh.Permissions, error) { + applyConnMetadata(ctx, conn) + if ok := srv.PasswordHandler(ctx, string(password)); !ok { + return ctx.Permissions().Permissions, fmt.Errorf("permission denied") + } + return ctx.Permissions().Permissions, nil + } + } + if srv.PublicKeyHandler != nil { + config.PublicKeyCallback = func(conn gossh.ConnMetadata, key gossh.PublicKey) (*gossh.Permissions, error) { + applyConnMetadata(ctx, conn) + if err := srv.PublicKeyHandler(ctx, key); err != nil { + return ctx.Permissions().Permissions, err + } + ctx.SetValue(ContextKeyPublicKey, key) + return ctx.Permissions().Permissions, nil + } + } + if srv.KeyboardInteractiveHandler != nil { + config.KeyboardInteractiveCallback = func(conn gossh.ConnMetadata, challenger gossh.KeyboardInteractiveChallenge) (*gossh.Permissions, error) { + applyConnMetadata(ctx, conn) + if ok := srv.KeyboardInteractiveHandler(ctx, challenger); !ok { + return ctx.Permissions().Permissions, fmt.Errorf("permission denied") + } + return ctx.Permissions().Permissions, nil + } + } + if srv.NoClientAuthHandler != nil { + config.NoClientAuthCallback = func(conn gossh.ConnMetadata) (*gossh.Permissions, error) { + applyConnMetadata(ctx, conn) + if err := srv.NoClientAuthHandler(ctx); err != nil { + return ctx.Permissions().Permissions, err + } + return ctx.Permissions().Permissions, nil + } + } + return config +} + +// Handle sets the Handler for the server. +func (srv *Server) Handle(fn Handler) { + srv.mu.Lock() + defer srv.mu.Unlock() + + srv.Handler = fn +} + +// Close immediately closes all active listeners and all active +// connections. +// +// Close returns any error returned from closing the Server's +// underlying Listener(s). +func (srv *Server) Close() error { + srv.mu.Lock() + defer srv.mu.Unlock() + + srv.closeDoneChanLocked() + err := srv.closeListenersLocked() + for c := range srv.conns { + c.Close() + delete(srv.conns, c) + } + return err +} + +// Shutdown gracefully shuts down the server without interrupting any +// active connections. Shutdown works by first closing all open +// listeners, and then waiting indefinitely for connections to close. +// If the provided context expires before the shutdown is complete, +// then the context's error is returned. +func (srv *Server) Shutdown(ctx context.Context) error { + srv.mu.Lock() + lnerr := srv.closeListenersLocked() + srv.closeDoneChanLocked() + srv.mu.Unlock() + + finished := make(chan struct{}, 1) + go func() { + srv.listenerWg.Wait() + srv.connWg.Wait() + finished <- struct{}{} + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-finished: + return lnerr + } +} + +// Serve accepts incoming connections on the Listener l, creating a new +// connection goroutine for each. The connection goroutines read requests and then +// calls srv.Handler to handle sessions. +// +// Serve always returns a non-nil error. +func (srv *Server) Serve(l net.Listener) error { + srv.ensureHandlers() + defer l.Close() + if err := srv.ensureHostSigner(); err != nil { + return err + } + if srv.Handler == nil { + srv.Handler = DefaultHandler + } + var tempDelay time.Duration + + srv.trackListener(l, true) + defer srv.trackListener(l, false) + for { + conn, e := l.Accept() + if e != nil { + select { + case <-srv.getDoneChan(): + return ErrServerClosed + default: + } + if ne, ok := e.(net.Error); ok && ne.Temporary() { + if tempDelay == 0 { + tempDelay = 5 * time.Millisecond + } else { + tempDelay *= 2 + } + if max := 1 * time.Second; tempDelay > max { + tempDelay = max + } + time.Sleep(tempDelay) + continue + } + return e + } + go srv.HandleConn(conn) + } +} + +func (srv *Server) HandleConn(newConn net.Conn) { + ctx, cancel := newContext(srv) + if srv.ConnCallback != nil { + cbConn := srv.ConnCallback(ctx, newConn) + if cbConn == nil { + newConn.Close() + return + } + newConn = cbConn + } + conn := &serverConn{ + Conn: newConn, + idleTimeout: srv.IdleTimeout, + closeCanceler: cancel, + } + if srv.MaxTimeout > 0 { + conn.maxDeadline = time.Now().Add(srv.MaxTimeout) + } + defer conn.Close() + sshConn, chans, reqs, err := gossh.NewServerConn(conn, srv.config(ctx)) + if err != nil { + if srv.ConnectionFailedCallback != nil { + srv.ConnectionFailedCallback(conn, err) + } + return + } + + srv.trackConn(sshConn, true) + defer srv.trackConn(sshConn, false) + + ctx.SetValue(ContextKeyConn, sshConn) + applyConnMetadata(ctx, sshConn) + //go gossh.DiscardRequests(reqs) + go srv.handleRequests(ctx, reqs) + for ch := range chans { + handler := srv.ChannelHandlers[ch.ChannelType()] + if handler == nil { + handler = srv.ChannelHandlers["default"] + } + if handler == nil { + ch.Reject(gossh.UnknownChannelType, "unsupported channel type") + continue + } + go handler(srv, sshConn, ch, ctx) + } +} + +func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) { + for req := range in { + handler := srv.RequestHandlers[req.Type] + if handler == nil { + handler = srv.RequestHandlers["default"] + } + if handler == nil { + req.Reply(false, nil) + continue + } + /*reqCtx, cancel := context.WithCancel(ctx) + defer cancel() */ + ret, payload := handler(ctx, srv, req) + req.Reply(ret, payload) + } +} + +// ListenAndServe listens on the TCP network address srv.Addr and then calls +// Serve to handle incoming connections. If srv.Addr is blank, ":22" is used. +// ListenAndServe always returns a non-nil error. +func (srv *Server) ListenAndServe() error { + addr := srv.Addr + if addr == "" { + addr = ":22" + } + ln, err := net.Listen("tcp", addr) + if err != nil { + return err + } + return srv.Serve(ln) +} + +// AddHostKey adds a private key as a host key. If an existing host key exists +// with the same algorithm, it is overwritten. Each server config must have at +// least one host key. +func (srv *Server) AddHostKey(key Signer) { + srv.mu.Lock() + defer srv.mu.Unlock() + + // these are later added via AddHostKey on ServerConfig, which performs the + // check for one of every algorithm. + + // This check is based on the AddHostKey method from the x/crypto/ssh + // library. This allows us to only keep one active key for each type on a + // server at once. So, if you're dynamically updating keys at runtime, this + // list will not keep growing. + for i, k := range srv.HostSigners { + if k.PublicKey().Type() == key.PublicKey().Type() { + srv.HostSigners[i] = key + return + } + } + + srv.HostSigners = append(srv.HostSigners, key) +} + +// SetOption runs a functional option against the server. +func (srv *Server) SetOption(option Option) error { + // NOTE: there is a potential race here for any option that doesn't call an + // internal method. We can't actually lock here because if something calls + // (as an example) AddHostKey, it will deadlock. + + //srv.mu.Lock() + //defer srv.mu.Unlock() + + return option(srv) +} + +func (srv *Server) getDoneChan() <-chan struct{} { + srv.mu.Lock() + defer srv.mu.Unlock() + + return srv.getDoneChanLocked() +} + +func (srv *Server) getDoneChanLocked() chan struct{} { + if srv.doneChan == nil { + srv.doneChan = make(chan struct{}) + } + return srv.doneChan +} + +func (srv *Server) closeDoneChanLocked() { + ch := srv.getDoneChanLocked() + select { + case <-ch: + // Already closed. Don't close again. + default: + // Safe to close here. We're the only closer, guarded + // by srv.mu. + close(ch) + } +} + +func (srv *Server) closeListenersLocked() error { + var err error + for ln := range srv.listeners { + if cerr := ln.Close(); cerr != nil && err == nil { + err = cerr + } + delete(srv.listeners, ln) + } + return err +} + +func (srv *Server) trackListener(ln net.Listener, add bool) { + srv.mu.Lock() + defer srv.mu.Unlock() + + if srv.listeners == nil { + srv.listeners = make(map[net.Listener]struct{}) + } + if add { + // If the *Server is being reused after a previous + // Close or Shutdown, reset its doneChan: + if len(srv.listeners) == 0 && len(srv.conns) == 0 { + srv.doneChan = nil + } + srv.listeners[ln] = struct{}{} + srv.listenerWg.Add(1) + } else { + delete(srv.listeners, ln) + srv.listenerWg.Done() + } +} + +func (srv *Server) trackConn(c *gossh.ServerConn, add bool) { + srv.mu.Lock() + defer srv.mu.Unlock() + + if srv.conns == nil { + srv.conns = make(map[*gossh.ServerConn]struct{}) + } + if add { + srv.conns[c] = struct{}{} + srv.connWg.Add(1) + } else { + delete(srv.conns, c) + srv.connWg.Done() + } +} diff --git a/tempfork/gliderlabs/ssh/server_test.go b/tempfork/gliderlabs/ssh/server_test.go index 1a63bb4b2..177c07117 100644 --- a/tempfork/gliderlabs/ssh/server_test.go +++ b/tempfork/gliderlabs/ssh/server_test.go @@ -1,128 +1,128 @@ -//go:build glidertests - -package ssh - -import ( - "bytes" - "context" - "io" - "testing" - "time" -) - -func TestAddHostKey(t *testing.T) { - s := Server{} - signer, err := generateSigner() - if err != nil { - t.Fatal(err) - } - s.AddHostKey(signer) - if len(s.HostSigners) != 1 { - t.Fatal("Key was not properly added") - } - signer, err = generateSigner() - if err != nil { - t.Fatal(err) - } - s.AddHostKey(signer) - if len(s.HostSigners) != 1 { - t.Fatal("Key was not properly replaced") - } -} - -func TestServerShutdown(t *testing.T) { - l := newLocalListener() - testBytes := []byte("Hello world\n") - s := &Server{ - Handler: func(s Session) { - s.Write(testBytes) - time.Sleep(50 * time.Millisecond) - }, - } - go func() { - err := s.Serve(l) - if err != nil && err != ErrServerClosed { - t.Fatal(err) - } - }() - sessDone := make(chan struct{}) - sess, _, cleanup := newClientSession(t, l.Addr().String(), nil) - go func() { - defer cleanup() - defer close(sessDone) - var stdout bytes.Buffer - sess.Stdout = &stdout - if err := sess.Run(""); err != nil { - t.Fatal(err) - } - if !bytes.Equal(stdout.Bytes(), testBytes) { - t.Fatalf("expected = %s; got %s", testBytes, stdout.Bytes()) - } - }() - - srvDone := make(chan struct{}) - go func() { - defer close(srvDone) - err := s.Shutdown(context.Background()) - if err != nil { - t.Fatal(err) - } - }() - - timeout := time.After(2 * time.Second) - select { - case <-timeout: - t.Fatal("timeout") - return - case <-srvDone: - // TODO: add timeout for sessDone - <-sessDone - return - } -} - -func TestServerClose(t *testing.T) { - l := newLocalListener() - s := &Server{ - Handler: func(s Session) { - time.Sleep(5 * time.Second) - }, - } - go func() { - err := s.Serve(l) - if err != nil && err != ErrServerClosed { - t.Fatal(err) - } - }() - - clientDoneChan := make(chan struct{}) - closeDoneChan := make(chan struct{}) - - sess, _, cleanup := newClientSession(t, l.Addr().String(), nil) - go func() { - defer cleanup() - defer close(clientDoneChan) - <-closeDoneChan - if err := sess.Run(""); err != nil && err != io.EOF { - t.Fatal(err) - } - }() - - go func() { - err := s.Close() - if err != nil { - t.Fatal(err) - } - close(closeDoneChan) - }() - - timeout := time.After(100 * time.Millisecond) - select { - case <-timeout: - t.Error("timeout") - return - case <-s.getDoneChan(): - <-clientDoneChan - return - } -} +//go:build glidertests + +package ssh + +import ( + "bytes" + "context" + "io" + "testing" + "time" +) + +func TestAddHostKey(t *testing.T) { + s := Server{} + signer, err := generateSigner() + if err != nil { + t.Fatal(err) + } + s.AddHostKey(signer) + if len(s.HostSigners) != 1 { + t.Fatal("Key was not properly added") + } + signer, err = generateSigner() + if err != nil { + t.Fatal(err) + } + s.AddHostKey(signer) + if len(s.HostSigners) != 1 { + t.Fatal("Key was not properly replaced") + } +} + +func TestServerShutdown(t *testing.T) { + l := newLocalListener() + testBytes := []byte("Hello world\n") + s := &Server{ + Handler: func(s Session) { + s.Write(testBytes) + time.Sleep(50 * time.Millisecond) + }, + } + go func() { + err := s.Serve(l) + if err != nil && err != ErrServerClosed { + t.Fatal(err) + } + }() + sessDone := make(chan struct{}) + sess, _, cleanup := newClientSession(t, l.Addr().String(), nil) + go func() { + defer cleanup() + defer close(sessDone) + var stdout bytes.Buffer + sess.Stdout = &stdout + if err := sess.Run(""); err != nil { + t.Fatal(err) + } + if !bytes.Equal(stdout.Bytes(), testBytes) { + t.Fatalf("expected = %s; got %s", testBytes, stdout.Bytes()) + } + }() + + srvDone := make(chan struct{}) + go func() { + defer close(srvDone) + err := s.Shutdown(context.Background()) + if err != nil { + t.Fatal(err) + } + }() + + timeout := time.After(2 * time.Second) + select { + case <-timeout: + t.Fatal("timeout") + return + case <-srvDone: + // TODO: add timeout for sessDone + <-sessDone + return + } +} + +func TestServerClose(t *testing.T) { + l := newLocalListener() + s := &Server{ + Handler: func(s Session) { + time.Sleep(5 * time.Second) + }, + } + go func() { + err := s.Serve(l) + if err != nil && err != ErrServerClosed { + t.Fatal(err) + } + }() + + clientDoneChan := make(chan struct{}) + closeDoneChan := make(chan struct{}) + + sess, _, cleanup := newClientSession(t, l.Addr().String(), nil) + go func() { + defer cleanup() + defer close(clientDoneChan) + <-closeDoneChan + if err := sess.Run(""); err != nil && err != io.EOF { + t.Fatal(err) + } + }() + + go func() { + err := s.Close() + if err != nil { + t.Fatal(err) + } + close(closeDoneChan) + }() + + timeout := time.After(100 * time.Millisecond) + select { + case <-timeout: + t.Error("timeout") + return + case <-s.getDoneChan(): + <-clientDoneChan + return + } +} diff --git a/tempfork/gliderlabs/ssh/session.go b/tempfork/gliderlabs/ssh/session.go index 2f43de739..0a4a21e53 100644 --- a/tempfork/gliderlabs/ssh/session.go +++ b/tempfork/gliderlabs/ssh/session.go @@ -1,386 +1,386 @@ -package ssh - -import ( - "bytes" - "context" - "errors" - "fmt" - "net" - "sync" - - "github.com/anmitsu/go-shlex" - gossh "github.com/tailscale/golang-x-crypto/ssh" -) - -// Session provides access to information about an SSH session and methods -// to read and write to the SSH channel with an embedded Channel interface from -// crypto/ssh. -// -// When Command() returns an empty slice, the user requested a shell. Otherwise -// the user is performing an exec with those command arguments. -// -// TODO: Signals -type Session interface { - gossh.Channel - - // User returns the username used when establishing the SSH connection. - User() string - - // RemoteAddr returns the net.Addr of the client side of the connection. - RemoteAddr() net.Addr - - // LocalAddr returns the net.Addr of the server side of the connection. - LocalAddr() net.Addr - - // Environ returns a copy of strings representing the environment set by the - // user for this session, in the form "key=value". - Environ() []string - - // Exit sends an exit status and then closes the session. - Exit(code int) error - - // Command returns a shell parsed slice of arguments that were provided by the - // user. Shell parsing splits the command string according to POSIX shell rules, - // which considers quoting not just whitespace. - Command() []string - - // RawCommand returns the exact command that was provided by the user. - RawCommand() string - - // Subsystem returns the subsystem requested by the user. - Subsystem() string - - // PublicKey returns the PublicKey used to authenticate. If a public key was not - // used it will return nil. - PublicKey() PublicKey - - // Context returns the connection's context. The returned context is always - // non-nil and holds the same data as the Context passed into auth - // handlers and callbacks. - // - // The context is canceled when the client's connection closes or I/O - // operation fails. - Context() context.Context - - // Permissions returns a copy of the Permissions object that was available for - // setup in the auth handlers via the Context. - Permissions() Permissions - - // Pty returns PTY information, a channel of window size changes, and a boolean - // of whether or not a PTY was accepted for this session. - Pty() (Pty, <-chan Window, bool) - - // Signals registers a channel to receive signals sent from the client. The - // channel must handle signal sends or it will block the SSH request loop. - // Registering nil will unregister the channel from signal sends. During the - // time no channel is registered signals are buffered up to a reasonable amount. - // If there are buffered signals when a channel is registered, they will be - // sent in order on the channel immediately after registering. - Signals(c chan<- Signal) - - // Break regisers a channel to receive notifications of break requests sent - // from the client. The channel must handle break requests, or it will block - // the request handling loop. Registering nil will unregister the channel. - // During the time that no channel is registered, breaks are ignored. - Break(c chan<- bool) - - // DisablePTYEmulation disables the session's default minimal PTY emulation. - // If you're setting the pty's termios settings from the Pty request, use - // this method to avoid corruption. - // Currently (2022-03-12) the only emulation implemented is NL-to-CRNL translation (`\n`=>`\r\n`). - // A call of DisablePTYEmulation must precede any call to Write. - DisablePTYEmulation() -} - -// maxSigBufSize is how many signals will be buffered -// when there is no signal channel specified -const maxSigBufSize = 128 - -func DefaultSessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { - ch, reqs, err := newChan.Accept() - if err != nil { - // TODO: trigger event callback - return - } - sess := &session{ - Channel: ch, - conn: conn, - handler: srv.Handler, - ptyCb: srv.PtyCallback, - sessReqCb: srv.SessionRequestCallback, - subsystemHandlers: srv.SubsystemHandlers, - ctx: ctx, - } - sess.handleRequests(reqs) -} - -type session struct { - sync.Mutex - gossh.Channel - conn *gossh.ServerConn - handler Handler - subsystemHandlers map[string]SubsystemHandler - handled bool - exited bool - pty *Pty - winch chan Window - env []string - ptyCb PtyCallback - sessReqCb SessionRequestCallback - rawCmd string - subsystem string - ctx Context - sigCh chan<- Signal - sigBuf []Signal - breakCh chan<- bool - disablePtyEmulation bool -} - -func (sess *session) DisablePTYEmulation() { - sess.disablePtyEmulation = true -} - -func (sess *session) Write(p []byte) (n int, err error) { - if sess.pty != nil && !sess.disablePtyEmulation { - m := len(p) - // normalize \n to \r\n when pty is accepted. - // this is a hardcoded shortcut since we don't support terminal modes. - p = bytes.Replace(p, []byte{'\n'}, []byte{'\r', '\n'}, -1) - p = bytes.Replace(p, []byte{'\r', '\r', '\n'}, []byte{'\r', '\n'}, -1) - n, err = sess.Channel.Write(p) - if n > m { - n = m - } - return - } - return sess.Channel.Write(p) -} - -func (sess *session) PublicKey() PublicKey { - sessionkey := sess.ctx.Value(ContextKeyPublicKey) - if sessionkey == nil { - return nil - } - return sessionkey.(PublicKey) -} - -func (sess *session) Permissions() Permissions { - // use context permissions because its properly - // wrapped and easier to dereference - perms := sess.ctx.Value(ContextKeyPermissions).(*Permissions) - return *perms -} - -func (sess *session) Context() context.Context { - return sess.ctx -} - -func (sess *session) Exit(code int) error { - sess.Lock() - defer sess.Unlock() - if sess.exited { - return errors.New("Session.Exit called multiple times") - } - sess.exited = true - - status := struct{ Status uint32 }{uint32(code)} - _, err := sess.SendRequest("exit-status", false, gossh.Marshal(&status)) - if err != nil { - return err - } - return sess.Close() -} - -func (sess *session) User() string { - return sess.conn.User() -} - -func (sess *session) RemoteAddr() net.Addr { - return sess.conn.RemoteAddr() -} - -func (sess *session) LocalAddr() net.Addr { - return sess.conn.LocalAddr() -} - -func (sess *session) Environ() []string { - return append([]string(nil), sess.env...) -} - -func (sess *session) RawCommand() string { - return sess.rawCmd -} - -func (sess *session) Command() []string { - cmd, _ := shlex.Split(sess.rawCmd, true) - return append([]string(nil), cmd...) -} - -func (sess *session) Subsystem() string { - return sess.subsystem -} - -func (sess *session) Pty() (Pty, <-chan Window, bool) { - if sess.pty != nil { - return *sess.pty, sess.winch, true - } - return Pty{}, sess.winch, false -} - -func (sess *session) Signals(c chan<- Signal) { - sess.Lock() - defer sess.Unlock() - sess.sigCh = c - if len(sess.sigBuf) > 0 { - go func() { - for _, sig := range sess.sigBuf { - sess.sigCh <- sig - } - }() - } -} - -func (sess *session) Break(c chan<- bool) { - sess.Lock() - defer sess.Unlock() - sess.breakCh = c -} - -func (sess *session) handleRequests(reqs <-chan *gossh.Request) { - for req := range reqs { - switch req.Type { - case "shell", "exec": - if sess.handled { - req.Reply(false, nil) - continue - } - - var payload = struct{ Value string }{} - gossh.Unmarshal(req.Payload, &payload) - sess.rawCmd = payload.Value - - // If there's a session policy callback, we need to confirm before - // accepting the session. - if sess.sessReqCb != nil && !sess.sessReqCb(sess, req.Type) { - sess.rawCmd = "" - req.Reply(false, nil) - continue - } - - sess.handled = true - req.Reply(true, nil) - - go func() { - sess.handler(sess) - sess.Exit(0) - }() - case "subsystem": - if sess.handled { - req.Reply(false, nil) - continue - } - - var payload = struct{ Value string }{} - gossh.Unmarshal(req.Payload, &payload) - sess.subsystem = payload.Value - - // If there's a session policy callback, we need to confirm before - // accepting the session. - if sess.sessReqCb != nil && !sess.sessReqCb(sess, req.Type) { - sess.rawCmd = "" - req.Reply(false, nil) - continue - } - - handler := sess.subsystemHandlers[payload.Value] - if handler == nil { - handler = sess.subsystemHandlers["default"] - } - if handler == nil { - req.Reply(false, nil) - continue - } - - sess.handled = true - req.Reply(true, nil) - - go func() { - handler(sess) - sess.Exit(0) - }() - case "env": - if sess.handled { - req.Reply(false, nil) - continue - } - var kv struct{ Key, Value string } - gossh.Unmarshal(req.Payload, &kv) - sess.env = append(sess.env, fmt.Sprintf("%s=%s", kv.Key, kv.Value)) - req.Reply(true, nil) - case "signal": - var payload struct{ Signal string } - gossh.Unmarshal(req.Payload, &payload) - sess.Lock() - if sess.sigCh != nil { - sess.sigCh <- Signal(payload.Signal) - } else { - if len(sess.sigBuf) < maxSigBufSize { - sess.sigBuf = append(sess.sigBuf, Signal(payload.Signal)) - } - } - sess.Unlock() - case "pty-req": - if sess.handled || sess.pty != nil { - req.Reply(false, nil) - continue - } - ptyReq, ok := parsePtyRequest(req.Payload) - if !ok { - req.Reply(false, nil) - continue - } - if sess.ptyCb != nil { - ok := sess.ptyCb(sess.ctx, ptyReq) - if !ok { - req.Reply(false, nil) - continue - } - } - sess.pty = &ptyReq - sess.winch = make(chan Window, 1) - sess.winch <- ptyReq.Window - defer func() { - // when reqs is closed - close(sess.winch) - }() - req.Reply(ok, nil) - case "window-change": - if sess.pty == nil { - req.Reply(false, nil) - continue - } - win, _, ok := parseWindow(req.Payload) - if ok { - sess.pty.Window = win - sess.winch <- win - } - req.Reply(ok, nil) - case agentRequestType: - // TODO: option/callback to allow agent forwarding - SetAgentRequested(sess.ctx) - req.Reply(true, nil) - case "break": - ok := false - sess.Lock() - if sess.breakCh != nil { - sess.breakCh <- true - ok = true - } - req.Reply(ok, nil) - sess.Unlock() - default: - // TODO: debug log - req.Reply(false, nil) - } - } -} +package ssh + +import ( + "bytes" + "context" + "errors" + "fmt" + "net" + "sync" + + "github.com/anmitsu/go-shlex" + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +// Session provides access to information about an SSH session and methods +// to read and write to the SSH channel with an embedded Channel interface from +// crypto/ssh. +// +// When Command() returns an empty slice, the user requested a shell. Otherwise +// the user is performing an exec with those command arguments. +// +// TODO: Signals +type Session interface { + gossh.Channel + + // User returns the username used when establishing the SSH connection. + User() string + + // RemoteAddr returns the net.Addr of the client side of the connection. + RemoteAddr() net.Addr + + // LocalAddr returns the net.Addr of the server side of the connection. + LocalAddr() net.Addr + + // Environ returns a copy of strings representing the environment set by the + // user for this session, in the form "key=value". + Environ() []string + + // Exit sends an exit status and then closes the session. + Exit(code int) error + + // Command returns a shell parsed slice of arguments that were provided by the + // user. Shell parsing splits the command string according to POSIX shell rules, + // which considers quoting not just whitespace. + Command() []string + + // RawCommand returns the exact command that was provided by the user. + RawCommand() string + + // Subsystem returns the subsystem requested by the user. + Subsystem() string + + // PublicKey returns the PublicKey used to authenticate. If a public key was not + // used it will return nil. + PublicKey() PublicKey + + // Context returns the connection's context. The returned context is always + // non-nil and holds the same data as the Context passed into auth + // handlers and callbacks. + // + // The context is canceled when the client's connection closes or I/O + // operation fails. + Context() context.Context + + // Permissions returns a copy of the Permissions object that was available for + // setup in the auth handlers via the Context. + Permissions() Permissions + + // Pty returns PTY information, a channel of window size changes, and a boolean + // of whether or not a PTY was accepted for this session. + Pty() (Pty, <-chan Window, bool) + + // Signals registers a channel to receive signals sent from the client. The + // channel must handle signal sends or it will block the SSH request loop. + // Registering nil will unregister the channel from signal sends. During the + // time no channel is registered signals are buffered up to a reasonable amount. + // If there are buffered signals when a channel is registered, they will be + // sent in order on the channel immediately after registering. + Signals(c chan<- Signal) + + // Break regisers a channel to receive notifications of break requests sent + // from the client. The channel must handle break requests, or it will block + // the request handling loop. Registering nil will unregister the channel. + // During the time that no channel is registered, breaks are ignored. + Break(c chan<- bool) + + // DisablePTYEmulation disables the session's default minimal PTY emulation. + // If you're setting the pty's termios settings from the Pty request, use + // this method to avoid corruption. + // Currently (2022-03-12) the only emulation implemented is NL-to-CRNL translation (`\n`=>`\r\n`). + // A call of DisablePTYEmulation must precede any call to Write. + DisablePTYEmulation() +} + +// maxSigBufSize is how many signals will be buffered +// when there is no signal channel specified +const maxSigBufSize = 128 + +func DefaultSessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { + ch, reqs, err := newChan.Accept() + if err != nil { + // TODO: trigger event callback + return + } + sess := &session{ + Channel: ch, + conn: conn, + handler: srv.Handler, + ptyCb: srv.PtyCallback, + sessReqCb: srv.SessionRequestCallback, + subsystemHandlers: srv.SubsystemHandlers, + ctx: ctx, + } + sess.handleRequests(reqs) +} + +type session struct { + sync.Mutex + gossh.Channel + conn *gossh.ServerConn + handler Handler + subsystemHandlers map[string]SubsystemHandler + handled bool + exited bool + pty *Pty + winch chan Window + env []string + ptyCb PtyCallback + sessReqCb SessionRequestCallback + rawCmd string + subsystem string + ctx Context + sigCh chan<- Signal + sigBuf []Signal + breakCh chan<- bool + disablePtyEmulation bool +} + +func (sess *session) DisablePTYEmulation() { + sess.disablePtyEmulation = true +} + +func (sess *session) Write(p []byte) (n int, err error) { + if sess.pty != nil && !sess.disablePtyEmulation { + m := len(p) + // normalize \n to \r\n when pty is accepted. + // this is a hardcoded shortcut since we don't support terminal modes. + p = bytes.Replace(p, []byte{'\n'}, []byte{'\r', '\n'}, -1) + p = bytes.Replace(p, []byte{'\r', '\r', '\n'}, []byte{'\r', '\n'}, -1) + n, err = sess.Channel.Write(p) + if n > m { + n = m + } + return + } + return sess.Channel.Write(p) +} + +func (sess *session) PublicKey() PublicKey { + sessionkey := sess.ctx.Value(ContextKeyPublicKey) + if sessionkey == nil { + return nil + } + return sessionkey.(PublicKey) +} + +func (sess *session) Permissions() Permissions { + // use context permissions because its properly + // wrapped and easier to dereference + perms := sess.ctx.Value(ContextKeyPermissions).(*Permissions) + return *perms +} + +func (sess *session) Context() context.Context { + return sess.ctx +} + +func (sess *session) Exit(code int) error { + sess.Lock() + defer sess.Unlock() + if sess.exited { + return errors.New("Session.Exit called multiple times") + } + sess.exited = true + + status := struct{ Status uint32 }{uint32(code)} + _, err := sess.SendRequest("exit-status", false, gossh.Marshal(&status)) + if err != nil { + return err + } + return sess.Close() +} + +func (sess *session) User() string { + return sess.conn.User() +} + +func (sess *session) RemoteAddr() net.Addr { + return sess.conn.RemoteAddr() +} + +func (sess *session) LocalAddr() net.Addr { + return sess.conn.LocalAddr() +} + +func (sess *session) Environ() []string { + return append([]string(nil), sess.env...) +} + +func (sess *session) RawCommand() string { + return sess.rawCmd +} + +func (sess *session) Command() []string { + cmd, _ := shlex.Split(sess.rawCmd, true) + return append([]string(nil), cmd...) +} + +func (sess *session) Subsystem() string { + return sess.subsystem +} + +func (sess *session) Pty() (Pty, <-chan Window, bool) { + if sess.pty != nil { + return *sess.pty, sess.winch, true + } + return Pty{}, sess.winch, false +} + +func (sess *session) Signals(c chan<- Signal) { + sess.Lock() + defer sess.Unlock() + sess.sigCh = c + if len(sess.sigBuf) > 0 { + go func() { + for _, sig := range sess.sigBuf { + sess.sigCh <- sig + } + }() + } +} + +func (sess *session) Break(c chan<- bool) { + sess.Lock() + defer sess.Unlock() + sess.breakCh = c +} + +func (sess *session) handleRequests(reqs <-chan *gossh.Request) { + for req := range reqs { + switch req.Type { + case "shell", "exec": + if sess.handled { + req.Reply(false, nil) + continue + } + + var payload = struct{ Value string }{} + gossh.Unmarshal(req.Payload, &payload) + sess.rawCmd = payload.Value + + // If there's a session policy callback, we need to confirm before + // accepting the session. + if sess.sessReqCb != nil && !sess.sessReqCb(sess, req.Type) { + sess.rawCmd = "" + req.Reply(false, nil) + continue + } + + sess.handled = true + req.Reply(true, nil) + + go func() { + sess.handler(sess) + sess.Exit(0) + }() + case "subsystem": + if sess.handled { + req.Reply(false, nil) + continue + } + + var payload = struct{ Value string }{} + gossh.Unmarshal(req.Payload, &payload) + sess.subsystem = payload.Value + + // If there's a session policy callback, we need to confirm before + // accepting the session. + if sess.sessReqCb != nil && !sess.sessReqCb(sess, req.Type) { + sess.rawCmd = "" + req.Reply(false, nil) + continue + } + + handler := sess.subsystemHandlers[payload.Value] + if handler == nil { + handler = sess.subsystemHandlers["default"] + } + if handler == nil { + req.Reply(false, nil) + continue + } + + sess.handled = true + req.Reply(true, nil) + + go func() { + handler(sess) + sess.Exit(0) + }() + case "env": + if sess.handled { + req.Reply(false, nil) + continue + } + var kv struct{ Key, Value string } + gossh.Unmarshal(req.Payload, &kv) + sess.env = append(sess.env, fmt.Sprintf("%s=%s", kv.Key, kv.Value)) + req.Reply(true, nil) + case "signal": + var payload struct{ Signal string } + gossh.Unmarshal(req.Payload, &payload) + sess.Lock() + if sess.sigCh != nil { + sess.sigCh <- Signal(payload.Signal) + } else { + if len(sess.sigBuf) < maxSigBufSize { + sess.sigBuf = append(sess.sigBuf, Signal(payload.Signal)) + } + } + sess.Unlock() + case "pty-req": + if sess.handled || sess.pty != nil { + req.Reply(false, nil) + continue + } + ptyReq, ok := parsePtyRequest(req.Payload) + if !ok { + req.Reply(false, nil) + continue + } + if sess.ptyCb != nil { + ok := sess.ptyCb(sess.ctx, ptyReq) + if !ok { + req.Reply(false, nil) + continue + } + } + sess.pty = &ptyReq + sess.winch = make(chan Window, 1) + sess.winch <- ptyReq.Window + defer func() { + // when reqs is closed + close(sess.winch) + }() + req.Reply(ok, nil) + case "window-change": + if sess.pty == nil { + req.Reply(false, nil) + continue + } + win, _, ok := parseWindow(req.Payload) + if ok { + sess.pty.Window = win + sess.winch <- win + } + req.Reply(ok, nil) + case agentRequestType: + // TODO: option/callback to allow agent forwarding + SetAgentRequested(sess.ctx) + req.Reply(true, nil) + case "break": + ok := false + sess.Lock() + if sess.breakCh != nil { + sess.breakCh <- true + ok = true + } + req.Reply(ok, nil) + sess.Unlock() + default: + // TODO: debug log + req.Reply(false, nil) + } + } +} diff --git a/tempfork/gliderlabs/ssh/session_test.go b/tempfork/gliderlabs/ssh/session_test.go index fddd67f6d..a60be5ec1 100644 --- a/tempfork/gliderlabs/ssh/session_test.go +++ b/tempfork/gliderlabs/ssh/session_test.go @@ -1,440 +1,440 @@ -//go:build glidertests - -package ssh - -import ( - "bytes" - "fmt" - "io" - "net" - "testing" - - gossh "github.com/tailscale/golang-x-crypto/ssh" -) - -func (srv *Server) serveOnce(l net.Listener) error { - srv.ensureHandlers() - if err := srv.ensureHostSigner(); err != nil { - return err - } - conn, e := l.Accept() - if e != nil { - return e - } - srv.ChannelHandlers = map[string]ChannelHandler{ - "session": DefaultSessionHandler, - "direct-tcpip": DirectTCPIPHandler, - } - srv.HandleConn(conn) - return nil -} - -func newLocalListener() net.Listener { - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - if l, err = net.Listen("tcp6", "[::1]:0"); err != nil { - panic(fmt.Sprintf("failed to listen on a port: %v", err)) - } - } - return l -} - -func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) { - if config == nil { - config = &gossh.ClientConfig{ - User: "testuser", - Auth: []gossh.AuthMethod{ - gossh.Password("testpass"), - }, - } - } - if config.HostKeyCallback == nil { - config.HostKeyCallback = gossh.InsecureIgnoreHostKey() - } - client, err := gossh.Dial("tcp", addr, config) - if err != nil { - t.Fatal(err) - } - session, err := client.NewSession() - if err != nil { - t.Fatal(err) - } - return session, client, func() { - session.Close() - client.Close() - } -} - -func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) { - l := newLocalListener() - go srv.serveOnce(l) - return newClientSession(t, l.Addr().String(), cfg) -} - -func TestStdout(t *testing.T) { - t.Parallel() - testBytes := []byte("Hello world\n") - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - s.Write(testBytes) - }, - }, nil) - defer cleanup() - var stdout bytes.Buffer - session.Stdout = &stdout - if err := session.Run(""); err != nil { - t.Fatal(err) - } - if !bytes.Equal(stdout.Bytes(), testBytes) { - t.Fatalf("stdout = %#v; want %#v", stdout.Bytes(), testBytes) - } -} - -func TestStderr(t *testing.T) { - t.Parallel() - testBytes := []byte("Hello world\n") - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - s.Stderr().Write(testBytes) - }, - }, nil) - defer cleanup() - var stderr bytes.Buffer - session.Stderr = &stderr - if err := session.Run(""); err != nil { - t.Fatal(err) - } - if !bytes.Equal(stderr.Bytes(), testBytes) { - t.Fatalf("stderr = %#v; want %#v", stderr.Bytes(), testBytes) - } -} - -func TestStdin(t *testing.T) { - t.Parallel() - testBytes := []byte("Hello world\n") - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - io.Copy(s, s) // stdin back into stdout - }, - }, nil) - defer cleanup() - var stdout bytes.Buffer - session.Stdout = &stdout - session.Stdin = bytes.NewBuffer(testBytes) - if err := session.Run(""); err != nil { - t.Fatal(err) - } - if !bytes.Equal(stdout.Bytes(), testBytes) { - t.Fatalf("stdout = %#v; want %#v given stdin = %#v", stdout.Bytes(), testBytes, testBytes) - } -} - -func TestUser(t *testing.T) { - t.Parallel() - testUser := []byte("progrium") - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - io.WriteString(s, s.User()) - }, - }, &gossh.ClientConfig{ - User: string(testUser), - }) - defer cleanup() - var stdout bytes.Buffer - session.Stdout = &stdout - if err := session.Run(""); err != nil { - t.Fatal(err) - } - if !bytes.Equal(stdout.Bytes(), testUser) { - t.Fatalf("stdout = %#v; want %#v given user = %#v", stdout.Bytes(), testUser, string(testUser)) - } -} - -func TestDefaultExitStatusZero(t *testing.T) { - t.Parallel() - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - // noop - }, - }, nil) - defer cleanup() - err := session.Run("") - if err != nil { - t.Fatalf("expected nil but got %v", err) - } -} - -func TestExplicitExitStatusZero(t *testing.T) { - t.Parallel() - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - s.Exit(0) - }, - }, nil) - defer cleanup() - err := session.Run("") - if err != nil { - t.Fatalf("expected nil but got %v", err) - } -} - -func TestExitStatusNonZero(t *testing.T) { - t.Parallel() - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - s.Exit(1) - }, - }, nil) - defer cleanup() - err := session.Run("") - e, ok := err.(*gossh.ExitError) - if !ok { - t.Fatalf("expected ExitError but got %T", err) - } - if e.ExitStatus() != 1 { - t.Fatalf("exit-status = %#v; want %#v", e.ExitStatus(), 1) - } -} - -func TestPty(t *testing.T) { - t.Parallel() - term := "xterm" - winWidth := 40 - winHeight := 80 - done := make(chan bool) - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - ptyReq, _, isPty := s.Pty() - if !isPty { - t.Fatalf("expected pty but none requested") - } - if ptyReq.Term != term { - t.Fatalf("expected term %#v but got %#v", term, ptyReq.Term) - } - if ptyReq.Window.Width != winWidth { - t.Fatalf("expected window width %#v but got %#v", winWidth, ptyReq.Window.Width) - } - if ptyReq.Window.Height != winHeight { - t.Fatalf("expected window height %#v but got %#v", winHeight, ptyReq.Window.Height) - } - close(done) - }, - }, nil) - defer cleanup() - if err := session.RequestPty(term, winHeight, winWidth, gossh.TerminalModes{}); err != nil { - t.Fatalf("expected nil but got %v", err) - } - if err := session.Shell(); err != nil { - t.Fatalf("expected nil but got %v", err) - } - <-done -} - -func TestPtyResize(t *testing.T) { - t.Parallel() - winch0 := Window{Width: 40, Height: 80} - winch1 := Window{Width: 80, Height: 160} - winch2 := Window{Width: 20, Height: 40} - winches := make(chan Window) - done := make(chan bool) - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - ptyReq, winCh, isPty := s.Pty() - if !isPty { - t.Fatalf("expected pty but none requested") - } - if ptyReq.Window != winch0 { - t.Fatalf("expected window %#v but got %#v", winch0, ptyReq.Window) - } - for win := range winCh { - winches <- win - } - close(done) - }, - }, nil) - defer cleanup() - // winch0 - if err := session.RequestPty("xterm", winch0.Height, winch0.Width, gossh.TerminalModes{}); err != nil { - t.Fatalf("expected nil but got %v", err) - } - if err := session.Shell(); err != nil { - t.Fatalf("expected nil but got %v", err) - } - gotWinch := <-winches - if gotWinch != winch0 { - t.Fatalf("expected window %#v but got %#v", winch0, gotWinch) - } - // winch1 - winchMsg := struct{ w, h uint32 }{uint32(winch1.Width), uint32(winch1.Height)} - ok, err := session.SendRequest("window-change", true, gossh.Marshal(&winchMsg)) - if err == nil && !ok { - t.Fatalf("unexpected error or bad reply on send request") - } - gotWinch = <-winches - if gotWinch != winch1 { - t.Fatalf("expected window %#v but got %#v", winch1, gotWinch) - } - // winch2 - winchMsg = struct{ w, h uint32 }{uint32(winch2.Width), uint32(winch2.Height)} - ok, err = session.SendRequest("window-change", true, gossh.Marshal(&winchMsg)) - if err == nil && !ok { - t.Fatalf("unexpected error or bad reply on send request") - } - gotWinch = <-winches - if gotWinch != winch2 { - t.Fatalf("expected window %#v but got %#v", winch2, gotWinch) - } - session.Close() - <-done -} - -func TestSignals(t *testing.T) { - t.Parallel() - - // errChan lets us get errors back from the session - errChan := make(chan error, 5) - - // doneChan lets us specify that we should exit. - doneChan := make(chan interface{}) - - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - // We need to use a buffered channel here, otherwise it's possible for the - // second call to Signal to get discarded. - signals := make(chan Signal, 2) - s.Signals(signals) - - select { - case sig := <-signals: - if sig != SIGINT { - errChan <- fmt.Errorf("expected signal %v but got %v", SIGINT, sig) - return - } - case <-doneChan: - errChan <- fmt.Errorf("Unexpected done") - return - } - - select { - case sig := <-signals: - if sig != SIGKILL { - errChan <- fmt.Errorf("expected signal %v but got %v", SIGKILL, sig) - return - } - case <-doneChan: - errChan <- fmt.Errorf("Unexpected done") - return - } - }, - }, nil) - defer cleanup() - - go func() { - session.Signal(gossh.SIGINT) - session.Signal(gossh.SIGKILL) - }() - - go func() { - errChan <- session.Run("") - }() - - err := <-errChan - close(doneChan) - - if err != nil { - t.Fatalf("expected nil but got %v", err) - } -} - -func TestBreakWithChanRegistered(t *testing.T) { - t.Parallel() - - // errChan lets us get errors back from the session - errChan := make(chan error, 5) - - // doneChan lets us specify that we should exit. - doneChan := make(chan interface{}) - - breakChan := make(chan bool) - - readyToReceiveBreak := make(chan bool) - - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - s.Break(breakChan) // register a break channel with the session - readyToReceiveBreak <- true - - select { - case <-breakChan: - io.WriteString(s, "break") - case <-doneChan: - errChan <- fmt.Errorf("Unexpected done") - return - } - }, - }, nil) - defer cleanup() - var stdout bytes.Buffer - session.Stdout = &stdout - go func() { - errChan <- session.Run("") - }() - - <-readyToReceiveBreak - ok, err := session.SendRequest("break", true, nil) - if err != nil { - t.Fatalf("expected nil but got %v", err) - } - if ok != true { - t.Fatalf("expected true but got %v", ok) - } - - err = <-errChan - close(doneChan) - - if err != nil { - t.Fatalf("expected nil but got %v", err) - } - if !bytes.Equal(stdout.Bytes(), []byte("break")) { - t.Fatalf("stdout = %#v, expected 'break'", stdout.Bytes()) - } -} - -func TestBreakWithoutChanRegistered(t *testing.T) { - t.Parallel() - - // errChan lets us get errors back from the session - errChan := make(chan error, 5) - - // doneChan lets us specify that we should exit. - doneChan := make(chan interface{}) - - waitUntilAfterBreakSent := make(chan bool) - - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - <-waitUntilAfterBreakSent - }, - }, nil) - defer cleanup() - var stdout bytes.Buffer - session.Stdout = &stdout - go func() { - errChan <- session.Run("") - }() - - ok, err := session.SendRequest("break", true, nil) - if err != nil { - t.Fatalf("expected nil but got %v", err) - } - if ok != false { - t.Fatalf("expected false but got %v", ok) - } - waitUntilAfterBreakSent <- true - - err = <-errChan - close(doneChan) - if err != nil { - t.Fatalf("expected nil but got %v", err) - } -} +//go:build glidertests + +package ssh + +import ( + "bytes" + "fmt" + "io" + "net" + "testing" + + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +func (srv *Server) serveOnce(l net.Listener) error { + srv.ensureHandlers() + if err := srv.ensureHostSigner(); err != nil { + return err + } + conn, e := l.Accept() + if e != nil { + return e + } + srv.ChannelHandlers = map[string]ChannelHandler{ + "session": DefaultSessionHandler, + "direct-tcpip": DirectTCPIPHandler, + } + srv.HandleConn(conn) + return nil +} + +func newLocalListener() net.Listener { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + if l, err = net.Listen("tcp6", "[::1]:0"); err != nil { + panic(fmt.Sprintf("failed to listen on a port: %v", err)) + } + } + return l +} + +func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) { + if config == nil { + config = &gossh.ClientConfig{ + User: "testuser", + Auth: []gossh.AuthMethod{ + gossh.Password("testpass"), + }, + } + } + if config.HostKeyCallback == nil { + config.HostKeyCallback = gossh.InsecureIgnoreHostKey() + } + client, err := gossh.Dial("tcp", addr, config) + if err != nil { + t.Fatal(err) + } + session, err := client.NewSession() + if err != nil { + t.Fatal(err) + } + return session, client, func() { + session.Close() + client.Close() + } +} + +func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) { + l := newLocalListener() + go srv.serveOnce(l) + return newClientSession(t, l.Addr().String(), cfg) +} + +func TestStdout(t *testing.T) { + t.Parallel() + testBytes := []byte("Hello world\n") + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + s.Write(testBytes) + }, + }, nil) + defer cleanup() + var stdout bytes.Buffer + session.Stdout = &stdout + if err := session.Run(""); err != nil { + t.Fatal(err) + } + if !bytes.Equal(stdout.Bytes(), testBytes) { + t.Fatalf("stdout = %#v; want %#v", stdout.Bytes(), testBytes) + } +} + +func TestStderr(t *testing.T) { + t.Parallel() + testBytes := []byte("Hello world\n") + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + s.Stderr().Write(testBytes) + }, + }, nil) + defer cleanup() + var stderr bytes.Buffer + session.Stderr = &stderr + if err := session.Run(""); err != nil { + t.Fatal(err) + } + if !bytes.Equal(stderr.Bytes(), testBytes) { + t.Fatalf("stderr = %#v; want %#v", stderr.Bytes(), testBytes) + } +} + +func TestStdin(t *testing.T) { + t.Parallel() + testBytes := []byte("Hello world\n") + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + io.Copy(s, s) // stdin back into stdout + }, + }, nil) + defer cleanup() + var stdout bytes.Buffer + session.Stdout = &stdout + session.Stdin = bytes.NewBuffer(testBytes) + if err := session.Run(""); err != nil { + t.Fatal(err) + } + if !bytes.Equal(stdout.Bytes(), testBytes) { + t.Fatalf("stdout = %#v; want %#v given stdin = %#v", stdout.Bytes(), testBytes, testBytes) + } +} + +func TestUser(t *testing.T) { + t.Parallel() + testUser := []byte("progrium") + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + io.WriteString(s, s.User()) + }, + }, &gossh.ClientConfig{ + User: string(testUser), + }) + defer cleanup() + var stdout bytes.Buffer + session.Stdout = &stdout + if err := session.Run(""); err != nil { + t.Fatal(err) + } + if !bytes.Equal(stdout.Bytes(), testUser) { + t.Fatalf("stdout = %#v; want %#v given user = %#v", stdout.Bytes(), testUser, string(testUser)) + } +} + +func TestDefaultExitStatusZero(t *testing.T) { + t.Parallel() + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + // noop + }, + }, nil) + defer cleanup() + err := session.Run("") + if err != nil { + t.Fatalf("expected nil but got %v", err) + } +} + +func TestExplicitExitStatusZero(t *testing.T) { + t.Parallel() + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + s.Exit(0) + }, + }, nil) + defer cleanup() + err := session.Run("") + if err != nil { + t.Fatalf("expected nil but got %v", err) + } +} + +func TestExitStatusNonZero(t *testing.T) { + t.Parallel() + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + s.Exit(1) + }, + }, nil) + defer cleanup() + err := session.Run("") + e, ok := err.(*gossh.ExitError) + if !ok { + t.Fatalf("expected ExitError but got %T", err) + } + if e.ExitStatus() != 1 { + t.Fatalf("exit-status = %#v; want %#v", e.ExitStatus(), 1) + } +} + +func TestPty(t *testing.T) { + t.Parallel() + term := "xterm" + winWidth := 40 + winHeight := 80 + done := make(chan bool) + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + ptyReq, _, isPty := s.Pty() + if !isPty { + t.Fatalf("expected pty but none requested") + } + if ptyReq.Term != term { + t.Fatalf("expected term %#v but got %#v", term, ptyReq.Term) + } + if ptyReq.Window.Width != winWidth { + t.Fatalf("expected window width %#v but got %#v", winWidth, ptyReq.Window.Width) + } + if ptyReq.Window.Height != winHeight { + t.Fatalf("expected window height %#v but got %#v", winHeight, ptyReq.Window.Height) + } + close(done) + }, + }, nil) + defer cleanup() + if err := session.RequestPty(term, winHeight, winWidth, gossh.TerminalModes{}); err != nil { + t.Fatalf("expected nil but got %v", err) + } + if err := session.Shell(); err != nil { + t.Fatalf("expected nil but got %v", err) + } + <-done +} + +func TestPtyResize(t *testing.T) { + t.Parallel() + winch0 := Window{Width: 40, Height: 80} + winch1 := Window{Width: 80, Height: 160} + winch2 := Window{Width: 20, Height: 40} + winches := make(chan Window) + done := make(chan bool) + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + ptyReq, winCh, isPty := s.Pty() + if !isPty { + t.Fatalf("expected pty but none requested") + } + if ptyReq.Window != winch0 { + t.Fatalf("expected window %#v but got %#v", winch0, ptyReq.Window) + } + for win := range winCh { + winches <- win + } + close(done) + }, + }, nil) + defer cleanup() + // winch0 + if err := session.RequestPty("xterm", winch0.Height, winch0.Width, gossh.TerminalModes{}); err != nil { + t.Fatalf("expected nil but got %v", err) + } + if err := session.Shell(); err != nil { + t.Fatalf("expected nil but got %v", err) + } + gotWinch := <-winches + if gotWinch != winch0 { + t.Fatalf("expected window %#v but got %#v", winch0, gotWinch) + } + // winch1 + winchMsg := struct{ w, h uint32 }{uint32(winch1.Width), uint32(winch1.Height)} + ok, err := session.SendRequest("window-change", true, gossh.Marshal(&winchMsg)) + if err == nil && !ok { + t.Fatalf("unexpected error or bad reply on send request") + } + gotWinch = <-winches + if gotWinch != winch1 { + t.Fatalf("expected window %#v but got %#v", winch1, gotWinch) + } + // winch2 + winchMsg = struct{ w, h uint32 }{uint32(winch2.Width), uint32(winch2.Height)} + ok, err = session.SendRequest("window-change", true, gossh.Marshal(&winchMsg)) + if err == nil && !ok { + t.Fatalf("unexpected error or bad reply on send request") + } + gotWinch = <-winches + if gotWinch != winch2 { + t.Fatalf("expected window %#v but got %#v", winch2, gotWinch) + } + session.Close() + <-done +} + +func TestSignals(t *testing.T) { + t.Parallel() + + // errChan lets us get errors back from the session + errChan := make(chan error, 5) + + // doneChan lets us specify that we should exit. + doneChan := make(chan interface{}) + + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + // We need to use a buffered channel here, otherwise it's possible for the + // second call to Signal to get discarded. + signals := make(chan Signal, 2) + s.Signals(signals) + + select { + case sig := <-signals: + if sig != SIGINT { + errChan <- fmt.Errorf("expected signal %v but got %v", SIGINT, sig) + return + } + case <-doneChan: + errChan <- fmt.Errorf("Unexpected done") + return + } + + select { + case sig := <-signals: + if sig != SIGKILL { + errChan <- fmt.Errorf("expected signal %v but got %v", SIGKILL, sig) + return + } + case <-doneChan: + errChan <- fmt.Errorf("Unexpected done") + return + } + }, + }, nil) + defer cleanup() + + go func() { + session.Signal(gossh.SIGINT) + session.Signal(gossh.SIGKILL) + }() + + go func() { + errChan <- session.Run("") + }() + + err := <-errChan + close(doneChan) + + if err != nil { + t.Fatalf("expected nil but got %v", err) + } +} + +func TestBreakWithChanRegistered(t *testing.T) { + t.Parallel() + + // errChan lets us get errors back from the session + errChan := make(chan error, 5) + + // doneChan lets us specify that we should exit. + doneChan := make(chan interface{}) + + breakChan := make(chan bool) + + readyToReceiveBreak := make(chan bool) + + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + s.Break(breakChan) // register a break channel with the session + readyToReceiveBreak <- true + + select { + case <-breakChan: + io.WriteString(s, "break") + case <-doneChan: + errChan <- fmt.Errorf("Unexpected done") + return + } + }, + }, nil) + defer cleanup() + var stdout bytes.Buffer + session.Stdout = &stdout + go func() { + errChan <- session.Run("") + }() + + <-readyToReceiveBreak + ok, err := session.SendRequest("break", true, nil) + if err != nil { + t.Fatalf("expected nil but got %v", err) + } + if ok != true { + t.Fatalf("expected true but got %v", ok) + } + + err = <-errChan + close(doneChan) + + if err != nil { + t.Fatalf("expected nil but got %v", err) + } + if !bytes.Equal(stdout.Bytes(), []byte("break")) { + t.Fatalf("stdout = %#v, expected 'break'", stdout.Bytes()) + } +} + +func TestBreakWithoutChanRegistered(t *testing.T) { + t.Parallel() + + // errChan lets us get errors back from the session + errChan := make(chan error, 5) + + // doneChan lets us specify that we should exit. + doneChan := make(chan interface{}) + + waitUntilAfterBreakSent := make(chan bool) + + session, _, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) { + <-waitUntilAfterBreakSent + }, + }, nil) + defer cleanup() + var stdout bytes.Buffer + session.Stdout = &stdout + go func() { + errChan <- session.Run("") + }() + + ok, err := session.SendRequest("break", true, nil) + if err != nil { + t.Fatalf("expected nil but got %v", err) + } + if ok != false { + t.Fatalf("expected false but got %v", ok) + } + waitUntilAfterBreakSent <- true + + err = <-errChan + close(doneChan) + if err != nil { + t.Fatalf("expected nil but got %v", err) + } +} diff --git a/tempfork/gliderlabs/ssh/ssh.go b/tempfork/gliderlabs/ssh/ssh.go index 4216ea97a..644cb257d 100644 --- a/tempfork/gliderlabs/ssh/ssh.go +++ b/tempfork/gliderlabs/ssh/ssh.go @@ -1,156 +1,156 @@ -package ssh - -import ( - "crypto/subtle" - "net" - - gossh "github.com/tailscale/golang-x-crypto/ssh" -) - -type Signal string - -// POSIX signals as listed in RFC 4254 Section 6.10. -const ( - SIGABRT Signal = "ABRT" - SIGALRM Signal = "ALRM" - SIGFPE Signal = "FPE" - SIGHUP Signal = "HUP" - SIGILL Signal = "ILL" - SIGINT Signal = "INT" - SIGKILL Signal = "KILL" - SIGPIPE Signal = "PIPE" - SIGQUIT Signal = "QUIT" - SIGSEGV Signal = "SEGV" - SIGTERM Signal = "TERM" - SIGUSR1 Signal = "USR1" - SIGUSR2 Signal = "USR2" -) - -// DefaultHandler is the default Handler used by Serve. -var DefaultHandler Handler - -// Option is a functional option handler for Server. -type Option func(*Server) error - -// Handler is a callback for handling established SSH sessions. -type Handler func(Session) - -// PublicKeyHandler is a callback for performing public key authentication. -type PublicKeyHandler func(ctx Context, key PublicKey) error - -type NoClientAuthHandler func(ctx Context) error - -type BannerHandler func(ctx Context) string - -// PasswordHandler is a callback for performing password authentication. -type PasswordHandler func(ctx Context, password string) bool - -// KeyboardInteractiveHandler is a callback for performing keyboard-interactive authentication. -type KeyboardInteractiveHandler func(ctx Context, challenger gossh.KeyboardInteractiveChallenge) bool - -// PtyCallback is a hook for allowing PTY sessions. -type PtyCallback func(ctx Context, pty Pty) bool - -// SessionRequestCallback is a callback for allowing or denying SSH sessions. -type SessionRequestCallback func(sess Session, requestType string) bool - -// ConnCallback is a hook for new connections before handling. -// It allows wrapping for timeouts and limiting by returning -// the net.Conn that will be used as the underlying connection. -type ConnCallback func(ctx Context, conn net.Conn) net.Conn - -// LocalPortForwardingCallback is a hook for allowing port forwarding -type LocalPortForwardingCallback func(ctx Context, destinationHost string, destinationPort uint32) bool - -// ReversePortForwardingCallback is a hook for allowing reverse port forwarding -type ReversePortForwardingCallback func(ctx Context, bindHost string, bindPort uint32) bool - -// ServerConfigCallback is a hook for creating custom default server configs -type ServerConfigCallback func(ctx Context) *gossh.ServerConfig - -// ConnectionFailedCallback is a hook for reporting failed connections -// Please note: the net.Conn is likely to be closed at this point -type ConnectionFailedCallback func(conn net.Conn, err error) - -// Window represents the size of a PTY window. -// -// See https://datatracker.ietf.org/doc/html/rfc4254#section-6.2 -// -// Zero dimension parameters MUST be ignored. The character/row dimensions -// override the pixel dimensions (when nonzero). Pixel dimensions refer -// to the drawable area of the window. -type Window struct { - // Width is the number of columns. - // It overrides WidthPixels. - Width int - // Height is the number of rows. - // It overrides HeightPixels. - Height int - - // WidthPixels is the drawable width of the window, in pixels. - WidthPixels int - // HeightPixels is the drawable height of the window, in pixels. - HeightPixels int -} - -// Pty represents a PTY request and configuration. -type Pty struct { - // Term is the TERM environment variable value. - Term string - - // Window is the Window sent as part of the pty-req. - Window Window - - // Modes represent a mapping of Terminal Mode opcode to value as it was - // requested by the client as part of the pty-req. These are outlined as - // part of https://datatracker.ietf.org/doc/html/rfc4254#section-8. - // - // The opcodes are defined as constants in github.com/tailscale/golang-x-crypto/ssh (VINTR,VQUIT,etc.). - // Boolean opcodes have values 0 or 1. - Modes gossh.TerminalModes -} - -// Serve accepts incoming SSH connections on the listener l, creating a new -// connection goroutine for each. The connection goroutines read requests and -// then calls handler to handle sessions. Handler is typically nil, in which -// case the DefaultHandler is used. -func Serve(l net.Listener, handler Handler, options ...Option) error { - srv := &Server{Handler: handler} - for _, option := range options { - if err := srv.SetOption(option); err != nil { - return err - } - } - return srv.Serve(l) -} - -// ListenAndServe listens on the TCP network address addr and then calls Serve -// with handler to handle sessions on incoming connections. Handler is typically -// nil, in which case the DefaultHandler is used. -func ListenAndServe(addr string, handler Handler, options ...Option) error { - srv := &Server{Addr: addr, Handler: handler} - for _, option := range options { - if err := srv.SetOption(option); err != nil { - return err - } - } - return srv.ListenAndServe() -} - -// Handle registers the handler as the DefaultHandler. -func Handle(handler Handler) { - DefaultHandler = handler -} - -// KeysEqual is constant time compare of the keys to avoid timing attacks. -func KeysEqual(ak, bk PublicKey) bool { - - //avoid panic if one of the keys is nil, return false instead - if ak == nil || bk == nil { - return false - } - - a := ak.Marshal() - b := bk.Marshal() - return (len(a) == len(b) && subtle.ConstantTimeCompare(a, b) == 1) -} +package ssh + +import ( + "crypto/subtle" + "net" + + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +type Signal string + +// POSIX signals as listed in RFC 4254 Section 6.10. +const ( + SIGABRT Signal = "ABRT" + SIGALRM Signal = "ALRM" + SIGFPE Signal = "FPE" + SIGHUP Signal = "HUP" + SIGILL Signal = "ILL" + SIGINT Signal = "INT" + SIGKILL Signal = "KILL" + SIGPIPE Signal = "PIPE" + SIGQUIT Signal = "QUIT" + SIGSEGV Signal = "SEGV" + SIGTERM Signal = "TERM" + SIGUSR1 Signal = "USR1" + SIGUSR2 Signal = "USR2" +) + +// DefaultHandler is the default Handler used by Serve. +var DefaultHandler Handler + +// Option is a functional option handler for Server. +type Option func(*Server) error + +// Handler is a callback for handling established SSH sessions. +type Handler func(Session) + +// PublicKeyHandler is a callback for performing public key authentication. +type PublicKeyHandler func(ctx Context, key PublicKey) error + +type NoClientAuthHandler func(ctx Context) error + +type BannerHandler func(ctx Context) string + +// PasswordHandler is a callback for performing password authentication. +type PasswordHandler func(ctx Context, password string) bool + +// KeyboardInteractiveHandler is a callback for performing keyboard-interactive authentication. +type KeyboardInteractiveHandler func(ctx Context, challenger gossh.KeyboardInteractiveChallenge) bool + +// PtyCallback is a hook for allowing PTY sessions. +type PtyCallback func(ctx Context, pty Pty) bool + +// SessionRequestCallback is a callback for allowing or denying SSH sessions. +type SessionRequestCallback func(sess Session, requestType string) bool + +// ConnCallback is a hook for new connections before handling. +// It allows wrapping for timeouts and limiting by returning +// the net.Conn that will be used as the underlying connection. +type ConnCallback func(ctx Context, conn net.Conn) net.Conn + +// LocalPortForwardingCallback is a hook for allowing port forwarding +type LocalPortForwardingCallback func(ctx Context, destinationHost string, destinationPort uint32) bool + +// ReversePortForwardingCallback is a hook for allowing reverse port forwarding +type ReversePortForwardingCallback func(ctx Context, bindHost string, bindPort uint32) bool + +// ServerConfigCallback is a hook for creating custom default server configs +type ServerConfigCallback func(ctx Context) *gossh.ServerConfig + +// ConnectionFailedCallback is a hook for reporting failed connections +// Please note: the net.Conn is likely to be closed at this point +type ConnectionFailedCallback func(conn net.Conn, err error) + +// Window represents the size of a PTY window. +// +// See https://datatracker.ietf.org/doc/html/rfc4254#section-6.2 +// +// Zero dimension parameters MUST be ignored. The character/row dimensions +// override the pixel dimensions (when nonzero). Pixel dimensions refer +// to the drawable area of the window. +type Window struct { + // Width is the number of columns. + // It overrides WidthPixels. + Width int + // Height is the number of rows. + // It overrides HeightPixels. + Height int + + // WidthPixels is the drawable width of the window, in pixels. + WidthPixels int + // HeightPixels is the drawable height of the window, in pixels. + HeightPixels int +} + +// Pty represents a PTY request and configuration. +type Pty struct { + // Term is the TERM environment variable value. + Term string + + // Window is the Window sent as part of the pty-req. + Window Window + + // Modes represent a mapping of Terminal Mode opcode to value as it was + // requested by the client as part of the pty-req. These are outlined as + // part of https://datatracker.ietf.org/doc/html/rfc4254#section-8. + // + // The opcodes are defined as constants in github.com/tailscale/golang-x-crypto/ssh (VINTR,VQUIT,etc.). + // Boolean opcodes have values 0 or 1. + Modes gossh.TerminalModes +} + +// Serve accepts incoming SSH connections on the listener l, creating a new +// connection goroutine for each. The connection goroutines read requests and +// then calls handler to handle sessions. Handler is typically nil, in which +// case the DefaultHandler is used. +func Serve(l net.Listener, handler Handler, options ...Option) error { + srv := &Server{Handler: handler} + for _, option := range options { + if err := srv.SetOption(option); err != nil { + return err + } + } + return srv.Serve(l) +} + +// ListenAndServe listens on the TCP network address addr and then calls Serve +// with handler to handle sessions on incoming connections. Handler is typically +// nil, in which case the DefaultHandler is used. +func ListenAndServe(addr string, handler Handler, options ...Option) error { + srv := &Server{Addr: addr, Handler: handler} + for _, option := range options { + if err := srv.SetOption(option); err != nil { + return err + } + } + return srv.ListenAndServe() +} + +// Handle registers the handler as the DefaultHandler. +func Handle(handler Handler) { + DefaultHandler = handler +} + +// KeysEqual is constant time compare of the keys to avoid timing attacks. +func KeysEqual(ak, bk PublicKey) bool { + + //avoid panic if one of the keys is nil, return false instead + if ak == nil || bk == nil { + return false + } + + a := ak.Marshal() + b := bk.Marshal() + return (len(a) == len(b) && subtle.ConstantTimeCompare(a, b) == 1) +} diff --git a/tempfork/gliderlabs/ssh/ssh_test.go b/tempfork/gliderlabs/ssh/ssh_test.go index 8772c03ad..aa301b048 100644 --- a/tempfork/gliderlabs/ssh/ssh_test.go +++ b/tempfork/gliderlabs/ssh/ssh_test.go @@ -1,17 +1,17 @@ -package ssh - -import ( - "testing" -) - -func TestKeysEqual(t *testing.T) { - defer func() { - if r := recover(); r != nil { - t.Errorf("The code did panic") - } - }() - - if KeysEqual(nil, nil) { - t.Error("two nil keys should not return true") - } -} +package ssh + +import ( + "testing" +) + +func TestKeysEqual(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Errorf("The code did panic") + } + }() + + if KeysEqual(nil, nil) { + t.Error("two nil keys should not return true") + } +} diff --git a/tempfork/gliderlabs/ssh/tcpip.go b/tempfork/gliderlabs/ssh/tcpip.go index d30bb15ac..056a0c734 100644 --- a/tempfork/gliderlabs/ssh/tcpip.go +++ b/tempfork/gliderlabs/ssh/tcpip.go @@ -1,193 +1,193 @@ -package ssh - -import ( - "io" - "log" - "net" - "strconv" - "sync" - - gossh "github.com/tailscale/golang-x-crypto/ssh" -) - -const ( - forwardedTCPChannelType = "forwarded-tcpip" -) - -// direct-tcpip data struct as specified in RFC4254, Section 7.2 -type localForwardChannelData struct { - DestAddr string - DestPort uint32 - - OriginAddr string - OriginPort uint32 -} - -// DirectTCPIPHandler can be enabled by adding it to the server's -// ChannelHandlers under direct-tcpip. -func DirectTCPIPHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { - d := localForwardChannelData{} - if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil { - newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error()) - return - } - - if srv.LocalPortForwardingCallback == nil || !srv.LocalPortForwardingCallback(ctx, d.DestAddr, d.DestPort) { - newChan.Reject(gossh.Prohibited, "port forwarding is disabled") - return - } - - dest := net.JoinHostPort(d.DestAddr, strconv.FormatInt(int64(d.DestPort), 10)) - - var dialer net.Dialer - dconn, err := dialer.DialContext(ctx, "tcp", dest) - if err != nil { - newChan.Reject(gossh.ConnectionFailed, err.Error()) - return - } - - ch, reqs, err := newChan.Accept() - if err != nil { - dconn.Close() - return - } - go gossh.DiscardRequests(reqs) - - go func() { - defer ch.Close() - defer dconn.Close() - io.Copy(ch, dconn) - }() - go func() { - defer ch.Close() - defer dconn.Close() - io.Copy(dconn, ch) - }() -} - -type remoteForwardRequest struct { - BindAddr string - BindPort uint32 -} - -type remoteForwardSuccess struct { - BindPort uint32 -} - -type remoteForwardCancelRequest struct { - BindAddr string - BindPort uint32 -} - -type remoteForwardChannelData struct { - DestAddr string - DestPort uint32 - OriginAddr string - OriginPort uint32 -} - -// ForwardedTCPHandler can be enabled by creating a ForwardedTCPHandler and -// adding the HandleSSHRequest callback to the server's RequestHandlers under -// tcpip-forward and cancel-tcpip-forward. -type ForwardedTCPHandler struct { - forwards map[string]net.Listener - sync.Mutex -} - -func (h *ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) { - h.Lock() - if h.forwards == nil { - h.forwards = make(map[string]net.Listener) - } - h.Unlock() - conn := ctx.Value(ContextKeyConn).(*gossh.ServerConn) - switch req.Type { - case "tcpip-forward": - var reqPayload remoteForwardRequest - if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil { - // TODO: log parse failure - return false, []byte{} - } - if srv.ReversePortForwardingCallback == nil || !srv.ReversePortForwardingCallback(ctx, reqPayload.BindAddr, reqPayload.BindPort) { - return false, []byte("port forwarding is disabled") - } - addr := net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(int(reqPayload.BindPort))) - ln, err := net.Listen("tcp", addr) - if err != nil { - // TODO: log listen failure - return false, []byte{} - } - _, destPortStr, _ := net.SplitHostPort(ln.Addr().String()) - destPort, _ := strconv.Atoi(destPortStr) - h.Lock() - h.forwards[addr] = ln - h.Unlock() - go func() { - <-ctx.Done() - h.Lock() - ln, ok := h.forwards[addr] - h.Unlock() - if ok { - ln.Close() - } - }() - go func() { - for { - c, err := ln.Accept() - if err != nil { - // TODO: log accept failure - break - } - originAddr, orignPortStr, _ := net.SplitHostPort(c.RemoteAddr().String()) - originPort, _ := strconv.Atoi(orignPortStr) - payload := gossh.Marshal(&remoteForwardChannelData{ - DestAddr: reqPayload.BindAddr, - DestPort: uint32(destPort), - OriginAddr: originAddr, - OriginPort: uint32(originPort), - }) - go func() { - ch, reqs, err := conn.OpenChannel(forwardedTCPChannelType, payload) - if err != nil { - // TODO: log failure to open channel - log.Println(err) - c.Close() - return - } - go gossh.DiscardRequests(reqs) - go func() { - defer ch.Close() - defer c.Close() - io.Copy(ch, c) - }() - go func() { - defer ch.Close() - defer c.Close() - io.Copy(c, ch) - }() - }() - } - h.Lock() - delete(h.forwards, addr) - h.Unlock() - }() - return true, gossh.Marshal(&remoteForwardSuccess{uint32(destPort)}) - - case "cancel-tcpip-forward": - var reqPayload remoteForwardCancelRequest - if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil { - // TODO: log parse failure - return false, []byte{} - } - addr := net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(int(reqPayload.BindPort))) - h.Lock() - ln, ok := h.forwards[addr] - h.Unlock() - if ok { - ln.Close() - } - return true, nil - default: - return false, nil - } -} +package ssh + +import ( + "io" + "log" + "net" + "strconv" + "sync" + + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +const ( + forwardedTCPChannelType = "forwarded-tcpip" +) + +// direct-tcpip data struct as specified in RFC4254, Section 7.2 +type localForwardChannelData struct { + DestAddr string + DestPort uint32 + + OriginAddr string + OriginPort uint32 +} + +// DirectTCPIPHandler can be enabled by adding it to the server's +// ChannelHandlers under direct-tcpip. +func DirectTCPIPHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { + d := localForwardChannelData{} + if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil { + newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error()) + return + } + + if srv.LocalPortForwardingCallback == nil || !srv.LocalPortForwardingCallback(ctx, d.DestAddr, d.DestPort) { + newChan.Reject(gossh.Prohibited, "port forwarding is disabled") + return + } + + dest := net.JoinHostPort(d.DestAddr, strconv.FormatInt(int64(d.DestPort), 10)) + + var dialer net.Dialer + dconn, err := dialer.DialContext(ctx, "tcp", dest) + if err != nil { + newChan.Reject(gossh.ConnectionFailed, err.Error()) + return + } + + ch, reqs, err := newChan.Accept() + if err != nil { + dconn.Close() + return + } + go gossh.DiscardRequests(reqs) + + go func() { + defer ch.Close() + defer dconn.Close() + io.Copy(ch, dconn) + }() + go func() { + defer ch.Close() + defer dconn.Close() + io.Copy(dconn, ch) + }() +} + +type remoteForwardRequest struct { + BindAddr string + BindPort uint32 +} + +type remoteForwardSuccess struct { + BindPort uint32 +} + +type remoteForwardCancelRequest struct { + BindAddr string + BindPort uint32 +} + +type remoteForwardChannelData struct { + DestAddr string + DestPort uint32 + OriginAddr string + OriginPort uint32 +} + +// ForwardedTCPHandler can be enabled by creating a ForwardedTCPHandler and +// adding the HandleSSHRequest callback to the server's RequestHandlers under +// tcpip-forward and cancel-tcpip-forward. +type ForwardedTCPHandler struct { + forwards map[string]net.Listener + sync.Mutex +} + +func (h *ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) { + h.Lock() + if h.forwards == nil { + h.forwards = make(map[string]net.Listener) + } + h.Unlock() + conn := ctx.Value(ContextKeyConn).(*gossh.ServerConn) + switch req.Type { + case "tcpip-forward": + var reqPayload remoteForwardRequest + if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil { + // TODO: log parse failure + return false, []byte{} + } + if srv.ReversePortForwardingCallback == nil || !srv.ReversePortForwardingCallback(ctx, reqPayload.BindAddr, reqPayload.BindPort) { + return false, []byte("port forwarding is disabled") + } + addr := net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(int(reqPayload.BindPort))) + ln, err := net.Listen("tcp", addr) + if err != nil { + // TODO: log listen failure + return false, []byte{} + } + _, destPortStr, _ := net.SplitHostPort(ln.Addr().String()) + destPort, _ := strconv.Atoi(destPortStr) + h.Lock() + h.forwards[addr] = ln + h.Unlock() + go func() { + <-ctx.Done() + h.Lock() + ln, ok := h.forwards[addr] + h.Unlock() + if ok { + ln.Close() + } + }() + go func() { + for { + c, err := ln.Accept() + if err != nil { + // TODO: log accept failure + break + } + originAddr, orignPortStr, _ := net.SplitHostPort(c.RemoteAddr().String()) + originPort, _ := strconv.Atoi(orignPortStr) + payload := gossh.Marshal(&remoteForwardChannelData{ + DestAddr: reqPayload.BindAddr, + DestPort: uint32(destPort), + OriginAddr: originAddr, + OriginPort: uint32(originPort), + }) + go func() { + ch, reqs, err := conn.OpenChannel(forwardedTCPChannelType, payload) + if err != nil { + // TODO: log failure to open channel + log.Println(err) + c.Close() + return + } + go gossh.DiscardRequests(reqs) + go func() { + defer ch.Close() + defer c.Close() + io.Copy(ch, c) + }() + go func() { + defer ch.Close() + defer c.Close() + io.Copy(c, ch) + }() + }() + } + h.Lock() + delete(h.forwards, addr) + h.Unlock() + }() + return true, gossh.Marshal(&remoteForwardSuccess{uint32(destPort)}) + + case "cancel-tcpip-forward": + var reqPayload remoteForwardCancelRequest + if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil { + // TODO: log parse failure + return false, []byte{} + } + addr := net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(int(reqPayload.BindPort))) + h.Lock() + ln, ok := h.forwards[addr] + h.Unlock() + if ok { + ln.Close() + } + return true, nil + default: + return false, nil + } +} diff --git a/tempfork/gliderlabs/ssh/tcpip_test.go b/tempfork/gliderlabs/ssh/tcpip_test.go index e1d74d566..118b5d53a 100644 --- a/tempfork/gliderlabs/ssh/tcpip_test.go +++ b/tempfork/gliderlabs/ssh/tcpip_test.go @@ -1,85 +1,85 @@ -//go:build glidertests - -package ssh - -import ( - "bytes" - "io" - "net" - "strconv" - "strings" - "testing" - - gossh "github.com/tailscale/golang-x-crypto/ssh" -) - -var sampleServerResponse = []byte("Hello world") - -func sampleSocketServer() net.Listener { - l := newLocalListener() - - go func() { - conn, err := l.Accept() - if err != nil { - return - } - conn.Write(sampleServerResponse) - conn.Close() - }() - - return l -} - -func newTestSessionWithForwarding(t *testing.T, forwardingEnabled bool) (net.Listener, *gossh.Client, func()) { - l := sampleSocketServer() - - _, client, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) {}, - LocalPortForwardingCallback: func(ctx Context, destinationHost string, destinationPort uint32) bool { - addr := net.JoinHostPort(destinationHost, strconv.FormatInt(int64(destinationPort), 10)) - if addr != l.Addr().String() { - panic("unexpected destinationHost: " + addr) - } - return forwardingEnabled - }, - }, nil) - - return l, client, func() { - cleanup() - l.Close() - } -} - -func TestLocalPortForwardingWorks(t *testing.T) { - t.Parallel() - - l, client, cleanup := newTestSessionWithForwarding(t, true) - defer cleanup() - - conn, err := client.Dial("tcp", l.Addr().String()) - if err != nil { - t.Fatalf("Error connecting to %v: %v", l.Addr().String(), err) - } - result, err := io.ReadAll(conn) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(result, sampleServerResponse) { - t.Fatalf("result = %#v; want %#v", result, sampleServerResponse) - } -} - -func TestLocalPortForwardingRespectsCallback(t *testing.T) { - t.Parallel() - - l, client, cleanup := newTestSessionWithForwarding(t, false) - defer cleanup() - - _, err := client.Dial("tcp", l.Addr().String()) - if err == nil { - t.Fatalf("Expected error connecting to %v but it succeeded", l.Addr().String()) - } - if !strings.Contains(err.Error(), "port forwarding is disabled") { - t.Fatalf("Expected permission error but got %#v", err) - } -} +//go:build glidertests + +package ssh + +import ( + "bytes" + "io" + "net" + "strconv" + "strings" + "testing" + + gossh "github.com/tailscale/golang-x-crypto/ssh" +) + +var sampleServerResponse = []byte("Hello world") + +func sampleSocketServer() net.Listener { + l := newLocalListener() + + go func() { + conn, err := l.Accept() + if err != nil { + return + } + conn.Write(sampleServerResponse) + conn.Close() + }() + + return l +} + +func newTestSessionWithForwarding(t *testing.T, forwardingEnabled bool) (net.Listener, *gossh.Client, func()) { + l := sampleSocketServer() + + _, client, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) {}, + LocalPortForwardingCallback: func(ctx Context, destinationHost string, destinationPort uint32) bool { + addr := net.JoinHostPort(destinationHost, strconv.FormatInt(int64(destinationPort), 10)) + if addr != l.Addr().String() { + panic("unexpected destinationHost: " + addr) + } + return forwardingEnabled + }, + }, nil) + + return l, client, func() { + cleanup() + l.Close() + } +} + +func TestLocalPortForwardingWorks(t *testing.T) { + t.Parallel() + + l, client, cleanup := newTestSessionWithForwarding(t, true) + defer cleanup() + + conn, err := client.Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatalf("Error connecting to %v: %v", l.Addr().String(), err) + } + result, err := io.ReadAll(conn) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(result, sampleServerResponse) { + t.Fatalf("result = %#v; want %#v", result, sampleServerResponse) + } +} + +func TestLocalPortForwardingRespectsCallback(t *testing.T) { + t.Parallel() + + l, client, cleanup := newTestSessionWithForwarding(t, false) + defer cleanup() + + _, err := client.Dial("tcp", l.Addr().String()) + if err == nil { + t.Fatalf("Expected error connecting to %v but it succeeded", l.Addr().String()) + } + if !strings.Contains(err.Error(), "port forwarding is disabled") { + t.Fatalf("Expected permission error but got %#v", err) + } +} diff --git a/tempfork/gliderlabs/ssh/util.go b/tempfork/gliderlabs/ssh/util.go index 7a6a18241..e3b5716a3 100644 --- a/tempfork/gliderlabs/ssh/util.go +++ b/tempfork/gliderlabs/ssh/util.go @@ -1,157 +1,157 @@ -package ssh - -import ( - "crypto/rand" - "crypto/rsa" - "encoding/binary" - - "github.com/tailscale/golang-x-crypto/ssh" -) - -func generateSigner() (ssh.Signer, error) { - key, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return nil, err - } - return ssh.NewSignerFromKey(key) -} - -func parsePtyRequest(payload []byte) (pty Pty, ok bool) { - // See https://datatracker.ietf.org/doc/html/rfc4254#section-6.2 - // 6.2. Requesting a Pseudo-Terminal - // A pseudo-terminal can be allocated for the session by sending the - // following message. - // byte SSH_MSG_CHANNEL_REQUEST - // uint32 recipient channel - // string "pty-req" - // boolean want_reply - // string TERM environment variable value (e.g., vt100) - // uint32 terminal width, characters (e.g., 80) - // uint32 terminal height, rows (e.g., 24) - // uint32 terminal width, pixels (e.g., 640) - // uint32 terminal height, pixels (e.g., 480) - // string encoded terminal modes - - // The payload starts from the TERM variable. - term, rem, ok := parseString(payload) - if !ok { - return - } - win, rem, ok := parseWindow(rem) - if !ok { - return - } - modes, ok := parseTerminalModes(rem) - if !ok { - return - } - pty = Pty{ - Term: term, - Window: win, - Modes: modes, - } - return -} - -func parseTerminalModes(in []byte) (modes ssh.TerminalModes, ok bool) { - // See https://datatracker.ietf.org/doc/html/rfc4254#section-8 - // 8. Encoding of Terminal Modes - // - // All 'encoded terminal modes' (as passed in a pty request) are encoded - // into a byte stream. It is intended that the coding be portable - // across different environments. The stream consists of opcode- - // argument pairs wherein the opcode is a byte value. Opcodes 1 to 159 - // have a single uint32 argument. Opcodes 160 to 255 are not yet - // defined, and cause parsing to stop (they should only be used after - // any other data). The stream is terminated by opcode TTY_OP_END - // (0x00). - // - // The client SHOULD put any modes it knows about in the stream, and the - // server MAY ignore any modes it does not know about. This allows some - // degree of machine-independence, at least between systems that use a - // POSIX-like tty interface. The protocol can support other systems as - // well, but the client may need to fill reasonable values for a number - // of parameters so the server pty gets set to a reasonable mode (the - // server leaves all unspecified mode bits in their default values, and - // only some combinations make sense). - _, rem, ok := parseUint32(in) - if !ok { - return - } - const ttyOpEnd = 0 - for len(rem) > 0 { - if modes == nil { - modes = make(ssh.TerminalModes) - } - code := uint8(rem[0]) - rem = rem[1:] - if code == ttyOpEnd || code > 160 { - break - } - var val uint32 - val, rem, ok = parseUint32(rem) - if !ok { - return - } - modes[code] = val - } - ok = true - return -} - -func parseWindow(s []byte) (win Window, rem []byte, ok bool) { - // See https://datatracker.ietf.org/doc/html/rfc4254#section-6.7 - // 6.7. Window Dimension Change Message - // When the window (terminal) size changes on the client side, it MAY - // send a message to the other side to inform it of the new dimensions. - - // byte SSH_MSG_CHANNEL_REQUEST - // uint32 recipient channel - // string "window-change" - // boolean FALSE - // uint32 terminal width, columns - // uint32 terminal height, rows - // uint32 terminal width, pixels - // uint32 terminal height, pixels - wCols, rem, ok := parseUint32(s) - if !ok { - return - } - hRows, rem, ok := parseUint32(rem) - if !ok { - return - } - wPixels, rem, ok := parseUint32(rem) - if !ok { - return - } - hPixels, rem, ok := parseUint32(rem) - if !ok { - return - } - win = Window{ - Width: int(wCols), - Height: int(hRows), - WidthPixels: int(wPixels), - HeightPixels: int(hPixels), - } - return -} - -func parseString(in []byte) (out string, rem []byte, ok bool) { - length, rem, ok := parseUint32(in) - if uint32(len(rem)) < length || !ok { - ok = false - return - } - out, rem = string(rem[:length]), rem[length:] - ok = true - return -} - -func parseUint32(in []byte) (uint32, []byte, bool) { - if len(in) < 4 { - return 0, nil, false - } - return binary.BigEndian.Uint32(in), in[4:], true -} +package ssh + +import ( + "crypto/rand" + "crypto/rsa" + "encoding/binary" + + "github.com/tailscale/golang-x-crypto/ssh" +) + +func generateSigner() (ssh.Signer, error) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + return ssh.NewSignerFromKey(key) +} + +func parsePtyRequest(payload []byte) (pty Pty, ok bool) { + // See https://datatracker.ietf.org/doc/html/rfc4254#section-6.2 + // 6.2. Requesting a Pseudo-Terminal + // A pseudo-terminal can be allocated for the session by sending the + // following message. + // byte SSH_MSG_CHANNEL_REQUEST + // uint32 recipient channel + // string "pty-req" + // boolean want_reply + // string TERM environment variable value (e.g., vt100) + // uint32 terminal width, characters (e.g., 80) + // uint32 terminal height, rows (e.g., 24) + // uint32 terminal width, pixels (e.g., 640) + // uint32 terminal height, pixels (e.g., 480) + // string encoded terminal modes + + // The payload starts from the TERM variable. + term, rem, ok := parseString(payload) + if !ok { + return + } + win, rem, ok := parseWindow(rem) + if !ok { + return + } + modes, ok := parseTerminalModes(rem) + if !ok { + return + } + pty = Pty{ + Term: term, + Window: win, + Modes: modes, + } + return +} + +func parseTerminalModes(in []byte) (modes ssh.TerminalModes, ok bool) { + // See https://datatracker.ietf.org/doc/html/rfc4254#section-8 + // 8. Encoding of Terminal Modes + // + // All 'encoded terminal modes' (as passed in a pty request) are encoded + // into a byte stream. It is intended that the coding be portable + // across different environments. The stream consists of opcode- + // argument pairs wherein the opcode is a byte value. Opcodes 1 to 159 + // have a single uint32 argument. Opcodes 160 to 255 are not yet + // defined, and cause parsing to stop (they should only be used after + // any other data). The stream is terminated by opcode TTY_OP_END + // (0x00). + // + // The client SHOULD put any modes it knows about in the stream, and the + // server MAY ignore any modes it does not know about. This allows some + // degree of machine-independence, at least between systems that use a + // POSIX-like tty interface. The protocol can support other systems as + // well, but the client may need to fill reasonable values for a number + // of parameters so the server pty gets set to a reasonable mode (the + // server leaves all unspecified mode bits in their default values, and + // only some combinations make sense). + _, rem, ok := parseUint32(in) + if !ok { + return + } + const ttyOpEnd = 0 + for len(rem) > 0 { + if modes == nil { + modes = make(ssh.TerminalModes) + } + code := uint8(rem[0]) + rem = rem[1:] + if code == ttyOpEnd || code > 160 { + break + } + var val uint32 + val, rem, ok = parseUint32(rem) + if !ok { + return + } + modes[code] = val + } + ok = true + return +} + +func parseWindow(s []byte) (win Window, rem []byte, ok bool) { + // See https://datatracker.ietf.org/doc/html/rfc4254#section-6.7 + // 6.7. Window Dimension Change Message + // When the window (terminal) size changes on the client side, it MAY + // send a message to the other side to inform it of the new dimensions. + + // byte SSH_MSG_CHANNEL_REQUEST + // uint32 recipient channel + // string "window-change" + // boolean FALSE + // uint32 terminal width, columns + // uint32 terminal height, rows + // uint32 terminal width, pixels + // uint32 terminal height, pixels + wCols, rem, ok := parseUint32(s) + if !ok { + return + } + hRows, rem, ok := parseUint32(rem) + if !ok { + return + } + wPixels, rem, ok := parseUint32(rem) + if !ok { + return + } + hPixels, rem, ok := parseUint32(rem) + if !ok { + return + } + win = Window{ + Width: int(wCols), + Height: int(hRows), + WidthPixels: int(wPixels), + HeightPixels: int(hPixels), + } + return +} + +func parseString(in []byte) (out string, rem []byte, ok bool) { + length, rem, ok := parseUint32(in) + if uint32(len(rem)) < length || !ok { + ok = false + return + } + out, rem = string(rem[:length]), rem[length:] + ok = true + return +} + +func parseUint32(in []byte) (uint32, []byte, bool) { + if len(in) < 4 { + return 0, nil, false + } + return binary.BigEndian.Uint32(in), in[4:], true +} diff --git a/tempfork/gliderlabs/ssh/wrap.go b/tempfork/gliderlabs/ssh/wrap.go index f44f5d9bf..17867d751 100644 --- a/tempfork/gliderlabs/ssh/wrap.go +++ b/tempfork/gliderlabs/ssh/wrap.go @@ -1,33 +1,33 @@ -package ssh - -import gossh "github.com/tailscale/golang-x-crypto/ssh" - -// PublicKey is an abstraction of different types of public keys. -type PublicKey interface { - gossh.PublicKey -} - -// The Permissions type holds fine-grained permissions that are specific to a -// user or a specific authentication method for a user. Permissions, except for -// "source-address", must be enforced in the server application layer, after -// successful authentication. -type Permissions struct { - *gossh.Permissions -} - -// A Signer can create signatures that verify against a public key. -type Signer interface { - gossh.Signer -} - -// ParseAuthorizedKey parses a public key from an authorized_keys file used in -// OpenSSH according to the sshd(8) manual page. -func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) { - return gossh.ParseAuthorizedKey(in) -} - -// ParsePublicKey parses an SSH public key formatted for use in -// the SSH wire protocol according to RFC 4253, section 6.6. -func ParsePublicKey(in []byte) (out PublicKey, err error) { - return gossh.ParsePublicKey(in) -} +package ssh + +import gossh "github.com/tailscale/golang-x-crypto/ssh" + +// PublicKey is an abstraction of different types of public keys. +type PublicKey interface { + gossh.PublicKey +} + +// The Permissions type holds fine-grained permissions that are specific to a +// user or a specific authentication method for a user. Permissions, except for +// "source-address", must be enforced in the server application layer, after +// successful authentication. +type Permissions struct { + *gossh.Permissions +} + +// A Signer can create signatures that verify against a public key. +type Signer interface { + gossh.Signer +} + +// ParseAuthorizedKey parses a public key from an authorized_keys file used in +// OpenSSH according to the sshd(8) manual page. +func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) { + return gossh.ParseAuthorizedKey(in) +} + +// ParsePublicKey parses an SSH public key formatted for use in +// the SSH wire protocol according to RFC 4253, section 6.6. +func ParsePublicKey(in []byte) (out PublicKey, err error) { + return gossh.ParsePublicKey(in) +} diff --git a/tempfork/heap/heap.go b/tempfork/heap/heap.go index 080b80ca5..3dfab492a 100644 --- a/tempfork/heap/heap.go +++ b/tempfork/heap/heap.go @@ -1,121 +1,121 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package heap provides heap operations for any type that implements -// heap.Interface. A heap is a tree with the property that each node is the -// minimum-valued node in its subtree. -// -// The minimum element in the tree is the root, at index 0. -// -// A heap is a common way to implement a priority queue. To build a priority -// queue, implement the Heap interface with the (negative) priority as the -// ordering for the Less method, so Push adds items while Pop removes the -// highest-priority item from the queue. The Examples include such an -// implementation; the file example_pq_test.go has the complete source. -// -// This package is a copy of the Go standard library's -// container/heap, but using generics. -package heap - -import "sort" - -// The Interface type describes the requirements -// for a type using the routines in this package. -// Any type that implements it may be used as a -// min-heap with the following invariants (established after -// Init has been called or if the data is empty or sorted): -// -// !h.Less(j, i) for 0 <= i < h.Len() and 2*i+1 <= j <= 2*i+2 and j < h.Len() -// -// Note that Push and Pop in this interface are for package heap's -// implementation to call. To add and remove things from the heap, -// use heap.Push and heap.Pop. -type Interface[V any] interface { - sort.Interface - Push(x V) // add x as element Len() - Pop() V // remove and return element Len() - 1. -} - -// Init establishes the heap invariants required by the other routines in this package. -// Init is idempotent with respect to the heap invariants -// and may be called whenever the heap invariants may have been invalidated. -// The complexity is O(n) where n = h.Len(). -func Init[V any](h Interface[V]) { - // heapify - n := h.Len() - for i := n/2 - 1; i >= 0; i-- { - down(h, i, n) - } -} - -// Push pushes the element x onto the heap. -// The complexity is O(log n) where n = h.Len(). -func Push[V any](h Interface[V], x V) { - h.Push(x) - up(h, h.Len()-1) -} - -// Pop removes and returns the minimum element (according to Less) from the heap. -// The complexity is O(log n) where n = h.Len(). -// Pop is equivalent to Remove(h, 0). -func Pop[V any](h Interface[V]) V { - n := h.Len() - 1 - h.Swap(0, n) - down(h, 0, n) - return h.Pop() -} - -// Remove removes and returns the element at index i from the heap. -// The complexity is O(log n) where n = h.Len(). -func Remove[V any](h Interface[V], i int) V { - n := h.Len() - 1 - if n != i { - h.Swap(i, n) - if !down(h, i, n) { - up(h, i) - } - } - return h.Pop() -} - -// Fix re-establishes the heap ordering after the element at index i has changed its value. -// Changing the value of the element at index i and then calling Fix is equivalent to, -// but less expensive than, calling Remove(h, i) followed by a Push of the new value. -// The complexity is O(log n) where n = h.Len(). -func Fix[V any](h Interface[V], i int) { - if !down(h, i, h.Len()) { - up(h, i) - } -} - -func up[V any](h Interface[V], j int) { - for { - i := (j - 1) / 2 // parent - if i == j || !h.Less(j, i) { - break - } - h.Swap(i, j) - j = i - } -} - -func down[V any](h Interface[V], i0, n int) bool { - i := i0 - for { - j1 := 2*i + 1 - if j1 >= n || j1 < 0 { // j1 < 0 after int overflow - break - } - j := j1 // left child - if j2 := j1 + 1; j2 < n && h.Less(j2, j1) { - j = j2 // = 2*i + 2 // right child - } - if !h.Less(j, i) { - break - } - h.Swap(i, j) - i = j - } - return i > i0 -} +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package heap provides heap operations for any type that implements +// heap.Interface. A heap is a tree with the property that each node is the +// minimum-valued node in its subtree. +// +// The minimum element in the tree is the root, at index 0. +// +// A heap is a common way to implement a priority queue. To build a priority +// queue, implement the Heap interface with the (negative) priority as the +// ordering for the Less method, so Push adds items while Pop removes the +// highest-priority item from the queue. The Examples include such an +// implementation; the file example_pq_test.go has the complete source. +// +// This package is a copy of the Go standard library's +// container/heap, but using generics. +package heap + +import "sort" + +// The Interface type describes the requirements +// for a type using the routines in this package. +// Any type that implements it may be used as a +// min-heap with the following invariants (established after +// Init has been called or if the data is empty or sorted): +// +// !h.Less(j, i) for 0 <= i < h.Len() and 2*i+1 <= j <= 2*i+2 and j < h.Len() +// +// Note that Push and Pop in this interface are for package heap's +// implementation to call. To add and remove things from the heap, +// use heap.Push and heap.Pop. +type Interface[V any] interface { + sort.Interface + Push(x V) // add x as element Len() + Pop() V // remove and return element Len() - 1. +} + +// Init establishes the heap invariants required by the other routines in this package. +// Init is idempotent with respect to the heap invariants +// and may be called whenever the heap invariants may have been invalidated. +// The complexity is O(n) where n = h.Len(). +func Init[V any](h Interface[V]) { + // heapify + n := h.Len() + for i := n/2 - 1; i >= 0; i-- { + down(h, i, n) + } +} + +// Push pushes the element x onto the heap. +// The complexity is O(log n) where n = h.Len(). +func Push[V any](h Interface[V], x V) { + h.Push(x) + up(h, h.Len()-1) +} + +// Pop removes and returns the minimum element (according to Less) from the heap. +// The complexity is O(log n) where n = h.Len(). +// Pop is equivalent to Remove(h, 0). +func Pop[V any](h Interface[V]) V { + n := h.Len() - 1 + h.Swap(0, n) + down(h, 0, n) + return h.Pop() +} + +// Remove removes and returns the element at index i from the heap. +// The complexity is O(log n) where n = h.Len(). +func Remove[V any](h Interface[V], i int) V { + n := h.Len() - 1 + if n != i { + h.Swap(i, n) + if !down(h, i, n) { + up(h, i) + } + } + return h.Pop() +} + +// Fix re-establishes the heap ordering after the element at index i has changed its value. +// Changing the value of the element at index i and then calling Fix is equivalent to, +// but less expensive than, calling Remove(h, i) followed by a Push of the new value. +// The complexity is O(log n) where n = h.Len(). +func Fix[V any](h Interface[V], i int) { + if !down(h, i, h.Len()) { + up(h, i) + } +} + +func up[V any](h Interface[V], j int) { + for { + i := (j - 1) / 2 // parent + if i == j || !h.Less(j, i) { + break + } + h.Swap(i, j) + j = i + } +} + +func down[V any](h Interface[V], i0, n int) bool { + i := i0 + for { + j1 := 2*i + 1 + if j1 >= n || j1 < 0 { // j1 < 0 after int overflow + break + } + j := j1 // left child + if j2 := j1 + 1; j2 < n && h.Less(j2, j1) { + j = j2 // = 2*i + 2 // right child + } + if !h.Less(j, i) { + break + } + h.Swap(i, j) + i = j + } + return i > i0 +} diff --git a/tka/aum_test.go b/tka/aum_test.go index 84b567477..4297efabf 100644 --- a/tka/aum_test.go +++ b/tka/aum_test.go @@ -1,253 +1,253 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "bytes" - "testing" - - "github.com/google/go-cmp/cmp" - "golang.org/x/crypto/blake2s" - "tailscale.com/types/tkatype" -) - -func TestSerialization(t *testing.T) { - uint2 := uint(2) - var fakeAUMHash AUMHash - - tcs := []struct { - Name string - AUM AUM - Expect []byte - }{ - { - "AddKey", - AUM{MessageKind: AUMAddKey, Key: &Key{}}, - []byte{ - 0xa3, // major type 5 (map), 3 items - 0x01, // |- major type 0 (int), value 1 (first key, MessageKind) - 0x01, // |- major type 0 (int), value 1 (first value, AUMAddKey) - 0x02, // |- major type 0 (int), value 2 (second key, PrevAUMHash) - 0xf6, // |- major type 7 (val), value null (second value, nil) - 0x03, // |- major type 0 (int), value 3 (third key, Key) - 0xa3, // |- major type 5 (map), 3 items (type Key) - 0x01, // |- major type 0 (int), value 1 (first key, Kind) - 0x00, // |- major type 0 (int), value 0 (first value) - 0x02, // |- major type 0 (int), value 2 (second key, Votes) - 0x00, // |- major type 0 (int), value 0 (first value) - 0x03, // |- major type 0 (int), value 3 (third key, Public) - 0xf6, // |- major type 7 (val), value null (third value, nil) - }, - }, - { - "RemoveKey", - AUM{MessageKind: AUMRemoveKey, KeyID: []byte{1, 2}}, - []byte{ - 0xa3, // major type 5 (map), 3 items - 0x01, // |- major type 0 (int), value 1 (first key, MessageKind) - 0x02, // |- major type 0 (int), value 2 (first value, AUMRemoveKey) - 0x02, // |- major type 0 (int), value 2 (second key, PrevAUMHash) - 0xf6, // |- major type 7 (val), value null (second value, nil) - 0x04, // |- major type 0 (int), value 4 (third key, KeyID) - 0x42, // |- major type 2 (byte string), 2 items - 0x01, // |- major type 0 (int), value 1 (byte 1) - 0x02, // |- major type 0 (int), value 2 (byte 2) - }, - }, - { - "UpdateKey", - AUM{MessageKind: AUMUpdateKey, Votes: &uint2, KeyID: []byte{1, 2}, Meta: map[string]string{"a": "b"}}, - []byte{ - 0xa5, // major type 5 (map), 5 items - 0x01, // |- major type 0 (int), value 1 (first key, MessageKind) - 0x04, // |- major type 0 (int), value 4 (first value, AUMUpdateKey) - 0x02, // |- major type 0 (int), value 2 (second key, PrevAUMHash) - 0xf6, // |- major type 7 (val), value null (second value, nil) - 0x04, // |- major type 0 (int), value 4 (third key, KeyID) - 0x42, // |- major type 2 (byte string), 2 items - 0x01, // |- major type 0 (int), value 1 (byte 1) - 0x02, // |- major type 0 (int), value 2 (byte 2) - 0x06, // |- major type 0 (int), value 6 (fourth key, Votes) - 0x02, // |- major type 0 (int), value 2 (forth value, 2) - 0x07, // |- major type 0 (int), value 7 (fifth key, Meta) - 0xa1, // |- major type 5 (map), 1 item (map[string]string type) - 0x61, // |- major type 3 (text string), value 1 (first key, one byte long) - 0x61, // |- byte 'a' - 0x61, // |- major type 3 (text string), value 1 (first value, one byte long) - 0x62, // |- byte 'b' - }, - }, - { - "Checkpoint", - AUM{MessageKind: AUMCheckpoint, PrevAUMHash: []byte{1, 2}, State: &State{ - LastAUMHash: &fakeAUMHash, - Keys: []Key{ - {Kind: Key25519, Public: []byte{5, 6}}, - }, - }}, - append( - append([]byte{ - 0xa3, // major type 5 (map), 3 items - 0x01, // |- major type 0 (int), value 1 (first key, MessageKind) - 0x05, // |- major type 0 (int), value 5 (first value, AUMCheckpoint) - 0x02, // |- major type 0 (int), value 2 (second key, PrevAUMHash) - 0x42, // |- major type 2 (byte string), 2 items (second value) - 0x01, // |- major type 0 (int), value 1 (byte 1) - 0x02, // |- major type 0 (int), value 2 (byte 2) - 0x05, // |- major type 0 (int), value 5 (third key, State) - 0xa3, // |- major type 5 (map), 3 items (third value, State type) - 0x01, // |- major type 0 (int), value 1 (first key, LastAUMHash) - 0x58, 0x20, // |- major type 2 (byte string), 32 items (first value) - }, - bytes.Repeat([]byte{0}, 32)...), - []byte{ - 0x02, // |- major type 0 (int), value 2 (second key, DisablementSecrets) - 0xf6, // |- major type 7 (val), value null (second value, nil) - 0x03, // |- major type 0 (int), value 3 (third key, Keys) - 0x81, // |- major type 4 (array), value 1 (one item in array) - 0xa3, // |- major type 5 (map), 3 items (Key type) - 0x01, // |- major type 0 (int), value 1 (first key, Kind) - 0x01, // |- major type 0 (int), value 1 (first value, Key25519) - 0x02, // |- major type 0 (int), value 2 (second key, Votes) - 0x00, // |- major type 0 (int), value 0 (second value, 0) - 0x03, // |- major type 0 (int), value 3 (third key, Public) - 0x42, // |- major type 2 (byte string), 2 items (third value) - 0x05, // |- major type 0 (int), value 5 (byte 5) - 0x06, // |- major type 0 (int), value 6 (byte 6) - }...), - }, - { - "Signature", - AUM{MessageKind: AUMAddKey, Signatures: []tkatype.Signature{{KeyID: []byte{1}}}}, - []byte{ - 0xa3, // major type 5 (map), 3 items - 0x01, // |- major type 0 (int), value 1 (first key, MessageKind) - 0x01, // |- major type 0 (int), value 1 (first value, AUMAddKey) - 0x02, // |- major type 0 (int), value 2 (second key, PrevAUMHash) - 0xf6, // |- major type 7 (val), value null (second value, nil) - 0x17, // |- major type 0 (int), value 22 (third key, Signatures) - 0x81, // |- major type 4 (array), value 1 (one item in array) - 0xa2, // |- major type 5 (map), 2 items (Signature type) - 0x01, // |- major type 0 (int), value 1 (first key, KeyID) - 0x41, // |- major type 2 (byte string), 1 item - 0x01, // |- major type 0 (int), value 1 (byte 1) - 0x02, // |- major type 0 (int), value 2 (second key, Signature) - 0xf6, // |- major type 7 (val), value null (second value, nil) - }, - }, - } - - for _, tc := range tcs { - t.Run(tc.Name, func(t *testing.T) { - data := []byte(tc.AUM.Serialize()) - if diff := cmp.Diff(tc.Expect, data); diff != "" { - t.Errorf("serialization differs (-want, +got):\n%s", diff) - } - - var decodedAUM AUM - if err := decodedAUM.Unserialize(data); err != nil { - t.Fatalf("Unmarshal failed: %v", err) - } - if diff := cmp.Diff(tc.AUM, decodedAUM); diff != "" { - t.Errorf("unmarshalled version differs (-want, +got):\n%s", diff) - } - }) - } -} - -func TestAUMWeight(t *testing.T) { - var fakeKeyID [blake2s.Size]byte - testingRand(t, 1).Read(fakeKeyID[:]) - - pub, _ := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - pub, _ = testingKey25519(t, 2) - key2 := Key{Kind: Key25519, Public: pub, Votes: 2} - - tcs := []struct { - Name string - AUM AUM - State State - Want uint - }{ - { - "Empty", - AUM{}, - State{}, - 0, - }, - { - "Key unknown", - AUM{ - Signatures: []tkatype.Signature{{KeyID: fakeKeyID[:]}}, - }, - State{}, - 0, - }, - { - "Unary key", - AUM{ - Signatures: []tkatype.Signature{{KeyID: key.MustID()}}, - }, - State{ - Keys: []Key{key}, - }, - 2, - }, - { - "Multiple keys", - AUM{ - Signatures: []tkatype.Signature{{KeyID: key.MustID()}, {KeyID: key2.MustID()}}, - }, - State{ - Keys: []Key{key, key2}, - }, - 4, - }, - { - "Double use", - AUM{ - Signatures: []tkatype.Signature{{KeyID: key.MustID()}, {KeyID: key.MustID()}}, - }, - State{ - Keys: []Key{key}, - }, - 2, - }, - } - - for _, tc := range tcs { - t.Run(tc.Name, func(t *testing.T) { - got := tc.AUM.Weight(tc.State) - if got != tc.Want { - t.Errorf("Weight() = %d, want %d", got, tc.Want) - } - }) - } -} - -func TestAUMHashes(t *testing.T) { - // .Hash(): a hash over everything. - // .SigHash(): a hash over everything except the signatures. - // The signatures are over a hash of the AUM, so - // using SigHash() breaks this circularity. - - aum := AUM{MessageKind: AUMAddKey, Key: &Key{Kind: Key25519}} - sigHash1 := aum.SigHash() - aumHash1 := aum.Hash() - - aum.Signatures = []tkatype.Signature{{KeyID: []byte{1, 2, 3, 4}}} - sigHash2 := aum.SigHash() - aumHash2 := aum.Hash() - if len(aum.Signatures) != 1 { - t.Error("signature was removed by one of the hash functions") - } - - if !bytes.Equal(sigHash1[:], sigHash1[:]) { - t.Errorf("signature hash dependent on signatures!\n\t1 = %x\n\t2 = %x", sigHash1, sigHash2) - } - if bytes.Equal(aumHash1[:], aumHash2[:]) { - t.Error("aum hash didnt change") - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "bytes" + "testing" + + "github.com/google/go-cmp/cmp" + "golang.org/x/crypto/blake2s" + "tailscale.com/types/tkatype" +) + +func TestSerialization(t *testing.T) { + uint2 := uint(2) + var fakeAUMHash AUMHash + + tcs := []struct { + Name string + AUM AUM + Expect []byte + }{ + { + "AddKey", + AUM{MessageKind: AUMAddKey, Key: &Key{}}, + []byte{ + 0xa3, // major type 5 (map), 3 items + 0x01, // |- major type 0 (int), value 1 (first key, MessageKind) + 0x01, // |- major type 0 (int), value 1 (first value, AUMAddKey) + 0x02, // |- major type 0 (int), value 2 (second key, PrevAUMHash) + 0xf6, // |- major type 7 (val), value null (second value, nil) + 0x03, // |- major type 0 (int), value 3 (third key, Key) + 0xa3, // |- major type 5 (map), 3 items (type Key) + 0x01, // |- major type 0 (int), value 1 (first key, Kind) + 0x00, // |- major type 0 (int), value 0 (first value) + 0x02, // |- major type 0 (int), value 2 (second key, Votes) + 0x00, // |- major type 0 (int), value 0 (first value) + 0x03, // |- major type 0 (int), value 3 (third key, Public) + 0xf6, // |- major type 7 (val), value null (third value, nil) + }, + }, + { + "RemoveKey", + AUM{MessageKind: AUMRemoveKey, KeyID: []byte{1, 2}}, + []byte{ + 0xa3, // major type 5 (map), 3 items + 0x01, // |- major type 0 (int), value 1 (first key, MessageKind) + 0x02, // |- major type 0 (int), value 2 (first value, AUMRemoveKey) + 0x02, // |- major type 0 (int), value 2 (second key, PrevAUMHash) + 0xf6, // |- major type 7 (val), value null (second value, nil) + 0x04, // |- major type 0 (int), value 4 (third key, KeyID) + 0x42, // |- major type 2 (byte string), 2 items + 0x01, // |- major type 0 (int), value 1 (byte 1) + 0x02, // |- major type 0 (int), value 2 (byte 2) + }, + }, + { + "UpdateKey", + AUM{MessageKind: AUMUpdateKey, Votes: &uint2, KeyID: []byte{1, 2}, Meta: map[string]string{"a": "b"}}, + []byte{ + 0xa5, // major type 5 (map), 5 items + 0x01, // |- major type 0 (int), value 1 (first key, MessageKind) + 0x04, // |- major type 0 (int), value 4 (first value, AUMUpdateKey) + 0x02, // |- major type 0 (int), value 2 (second key, PrevAUMHash) + 0xf6, // |- major type 7 (val), value null (second value, nil) + 0x04, // |- major type 0 (int), value 4 (third key, KeyID) + 0x42, // |- major type 2 (byte string), 2 items + 0x01, // |- major type 0 (int), value 1 (byte 1) + 0x02, // |- major type 0 (int), value 2 (byte 2) + 0x06, // |- major type 0 (int), value 6 (fourth key, Votes) + 0x02, // |- major type 0 (int), value 2 (forth value, 2) + 0x07, // |- major type 0 (int), value 7 (fifth key, Meta) + 0xa1, // |- major type 5 (map), 1 item (map[string]string type) + 0x61, // |- major type 3 (text string), value 1 (first key, one byte long) + 0x61, // |- byte 'a' + 0x61, // |- major type 3 (text string), value 1 (first value, one byte long) + 0x62, // |- byte 'b' + }, + }, + { + "Checkpoint", + AUM{MessageKind: AUMCheckpoint, PrevAUMHash: []byte{1, 2}, State: &State{ + LastAUMHash: &fakeAUMHash, + Keys: []Key{ + {Kind: Key25519, Public: []byte{5, 6}}, + }, + }}, + append( + append([]byte{ + 0xa3, // major type 5 (map), 3 items + 0x01, // |- major type 0 (int), value 1 (first key, MessageKind) + 0x05, // |- major type 0 (int), value 5 (first value, AUMCheckpoint) + 0x02, // |- major type 0 (int), value 2 (second key, PrevAUMHash) + 0x42, // |- major type 2 (byte string), 2 items (second value) + 0x01, // |- major type 0 (int), value 1 (byte 1) + 0x02, // |- major type 0 (int), value 2 (byte 2) + 0x05, // |- major type 0 (int), value 5 (third key, State) + 0xa3, // |- major type 5 (map), 3 items (third value, State type) + 0x01, // |- major type 0 (int), value 1 (first key, LastAUMHash) + 0x58, 0x20, // |- major type 2 (byte string), 32 items (first value) + }, + bytes.Repeat([]byte{0}, 32)...), + []byte{ + 0x02, // |- major type 0 (int), value 2 (second key, DisablementSecrets) + 0xf6, // |- major type 7 (val), value null (second value, nil) + 0x03, // |- major type 0 (int), value 3 (third key, Keys) + 0x81, // |- major type 4 (array), value 1 (one item in array) + 0xa3, // |- major type 5 (map), 3 items (Key type) + 0x01, // |- major type 0 (int), value 1 (first key, Kind) + 0x01, // |- major type 0 (int), value 1 (first value, Key25519) + 0x02, // |- major type 0 (int), value 2 (second key, Votes) + 0x00, // |- major type 0 (int), value 0 (second value, 0) + 0x03, // |- major type 0 (int), value 3 (third key, Public) + 0x42, // |- major type 2 (byte string), 2 items (third value) + 0x05, // |- major type 0 (int), value 5 (byte 5) + 0x06, // |- major type 0 (int), value 6 (byte 6) + }...), + }, + { + "Signature", + AUM{MessageKind: AUMAddKey, Signatures: []tkatype.Signature{{KeyID: []byte{1}}}}, + []byte{ + 0xa3, // major type 5 (map), 3 items + 0x01, // |- major type 0 (int), value 1 (first key, MessageKind) + 0x01, // |- major type 0 (int), value 1 (first value, AUMAddKey) + 0x02, // |- major type 0 (int), value 2 (second key, PrevAUMHash) + 0xf6, // |- major type 7 (val), value null (second value, nil) + 0x17, // |- major type 0 (int), value 22 (third key, Signatures) + 0x81, // |- major type 4 (array), value 1 (one item in array) + 0xa2, // |- major type 5 (map), 2 items (Signature type) + 0x01, // |- major type 0 (int), value 1 (first key, KeyID) + 0x41, // |- major type 2 (byte string), 1 item + 0x01, // |- major type 0 (int), value 1 (byte 1) + 0x02, // |- major type 0 (int), value 2 (second key, Signature) + 0xf6, // |- major type 7 (val), value null (second value, nil) + }, + }, + } + + for _, tc := range tcs { + t.Run(tc.Name, func(t *testing.T) { + data := []byte(tc.AUM.Serialize()) + if diff := cmp.Diff(tc.Expect, data); diff != "" { + t.Errorf("serialization differs (-want, +got):\n%s", diff) + } + + var decodedAUM AUM + if err := decodedAUM.Unserialize(data); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if diff := cmp.Diff(tc.AUM, decodedAUM); diff != "" { + t.Errorf("unmarshalled version differs (-want, +got):\n%s", diff) + } + }) + } +} + +func TestAUMWeight(t *testing.T) { + var fakeKeyID [blake2s.Size]byte + testingRand(t, 1).Read(fakeKeyID[:]) + + pub, _ := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + pub, _ = testingKey25519(t, 2) + key2 := Key{Kind: Key25519, Public: pub, Votes: 2} + + tcs := []struct { + Name string + AUM AUM + State State + Want uint + }{ + { + "Empty", + AUM{}, + State{}, + 0, + }, + { + "Key unknown", + AUM{ + Signatures: []tkatype.Signature{{KeyID: fakeKeyID[:]}}, + }, + State{}, + 0, + }, + { + "Unary key", + AUM{ + Signatures: []tkatype.Signature{{KeyID: key.MustID()}}, + }, + State{ + Keys: []Key{key}, + }, + 2, + }, + { + "Multiple keys", + AUM{ + Signatures: []tkatype.Signature{{KeyID: key.MustID()}, {KeyID: key2.MustID()}}, + }, + State{ + Keys: []Key{key, key2}, + }, + 4, + }, + { + "Double use", + AUM{ + Signatures: []tkatype.Signature{{KeyID: key.MustID()}, {KeyID: key.MustID()}}, + }, + State{ + Keys: []Key{key}, + }, + 2, + }, + } + + for _, tc := range tcs { + t.Run(tc.Name, func(t *testing.T) { + got := tc.AUM.Weight(tc.State) + if got != tc.Want { + t.Errorf("Weight() = %d, want %d", got, tc.Want) + } + }) + } +} + +func TestAUMHashes(t *testing.T) { + // .Hash(): a hash over everything. + // .SigHash(): a hash over everything except the signatures. + // The signatures are over a hash of the AUM, so + // using SigHash() breaks this circularity. + + aum := AUM{MessageKind: AUMAddKey, Key: &Key{Kind: Key25519}} + sigHash1 := aum.SigHash() + aumHash1 := aum.Hash() + + aum.Signatures = []tkatype.Signature{{KeyID: []byte{1, 2, 3, 4}}} + sigHash2 := aum.SigHash() + aumHash2 := aum.Hash() + if len(aum.Signatures) != 1 { + t.Error("signature was removed by one of the hash functions") + } + + if !bytes.Equal(sigHash1[:], sigHash1[:]) { + t.Errorf("signature hash dependent on signatures!\n\t1 = %x\n\t2 = %x", sigHash1, sigHash2) + } + if bytes.Equal(aumHash1[:], aumHash2[:]) { + t.Error("aum hash didnt change") + } +} diff --git a/tka/builder.go b/tka/builder.go index 19cd340f0..c14ba2330 100644 --- a/tka/builder.go +++ b/tka/builder.go @@ -1,180 +1,180 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "fmt" - "os" - - "tailscale.com/types/tkatype" -) - -// Types implementing Signer can sign update messages. -type Signer interface { - // SignAUM returns signatures for the AUM encoded by the given AUMSigHash. - SignAUM(tkatype.AUMSigHash) ([]tkatype.Signature, error) -} - -// UpdateBuilder implements a builder for changes to the tailnet -// key authority. -// -// Finalize must be called to compute the update messages, which -// must then be applied to all Authority objects using Inform(). -type UpdateBuilder struct { - a *Authority - signer Signer - - state State - parent AUMHash - - out []AUM -} - -func (b *UpdateBuilder) mkUpdate(update AUM) error { - prevHash := make([]byte, len(b.parent)) - copy(prevHash, b.parent[:]) - update.PrevAUMHash = prevHash - - if b.signer != nil { - sigs, err := b.signer.SignAUM(update.SigHash()) - if err != nil { - return fmt.Errorf("signing failed: %v", err) - } - update.Signatures = append(update.Signatures, sigs...) - } - if err := update.StaticValidate(); err != nil { - return fmt.Errorf("generated update was invalid: %v", err) - } - state, err := b.state.applyVerifiedAUM(update) - if err != nil { - return fmt.Errorf("update cannot be applied: %v", err) - } - - b.state = state - b.parent = update.Hash() - b.out = append(b.out, update) - return nil -} - -// AddKey adds a new key to the authority. -func (b *UpdateBuilder) AddKey(key Key) error { - keyID, err := key.ID() - if err != nil { - return err - } - - if _, err := b.state.GetKey(keyID); err == nil { - return fmt.Errorf("cannot add key %v: already exists", key) - } - return b.mkUpdate(AUM{MessageKind: AUMAddKey, Key: &key}) -} - -// RemoveKey removes a key from the authority. -func (b *UpdateBuilder) RemoveKey(keyID tkatype.KeyID) error { - if _, err := b.state.GetKey(keyID); err != nil { - return fmt.Errorf("failed reading key %x: %v", keyID, err) - } - return b.mkUpdate(AUM{MessageKind: AUMRemoveKey, KeyID: keyID}) -} - -// SetKeyVote updates the number of votes of an existing key. -func (b *UpdateBuilder) SetKeyVote(keyID tkatype.KeyID, votes uint) error { - if _, err := b.state.GetKey(keyID); err != nil { - return fmt.Errorf("failed reading key %x: %v", keyID, err) - } - return b.mkUpdate(AUM{MessageKind: AUMUpdateKey, Votes: &votes, KeyID: keyID}) -} - -// SetKeyMeta updates key-value metadata stored against an existing key. -// -// TODO(tom): Provide an API to update specific values rather than the whole -// map. -func (b *UpdateBuilder) SetKeyMeta(keyID tkatype.KeyID, meta map[string]string) error { - if _, err := b.state.GetKey(keyID); err != nil { - return fmt.Errorf("failed reading key %x: %v", keyID, err) - } - return b.mkUpdate(AUM{MessageKind: AUMUpdateKey, Meta: meta, KeyID: keyID}) -} - -func (b *UpdateBuilder) generateCheckpoint() error { - // Compute the checkpoint state. - state := b.a.state - for i, update := range b.out { - var err error - if state, err = state.applyVerifiedAUM(update); err != nil { - return fmt.Errorf("applying update %d: %v", i, err) - } - } - - // Checkpoints cant specify a parent AUM. - state.LastAUMHash = nil - return b.mkUpdate(AUM{MessageKind: AUMCheckpoint, State: &state}) -} - -// checkpointEvery sets how often a checkpoint AUM should be generated. -const checkpointEvery = 50 - -// Finalize returns the set of update message to actuate the update. -func (b *UpdateBuilder) Finalize(storage Chonk) ([]AUM, error) { - var ( - needCheckpoint bool = true - cursor AUMHash = b.a.Head() - ) - for i := len(b.out); i < checkpointEvery; i++ { - aum, err := storage.AUM(cursor) - if err != nil { - if err == os.ErrNotExist { - // The available chain is shorter than the interval to checkpoint at. - needCheckpoint = false - break - } - return nil, fmt.Errorf("reading AUM: %v", err) - } - - if aum.MessageKind == AUMCheckpoint { - needCheckpoint = false - break - } - - parent, hasParent := aum.Parent() - if !hasParent { - // We've hit the genesis update, so the chain is shorter than the interval to checkpoint at. - needCheckpoint = false - break - } - cursor = parent - } - - if needCheckpoint { - if err := b.generateCheckpoint(); err != nil { - return nil, fmt.Errorf("generating checkpoint: %v", err) - } - } - - // Check no AUMs were applied in the meantime - if len(b.out) > 0 { - if parent, _ := b.out[0].Parent(); parent != b.a.Head() { - return nil, fmt.Errorf("updates no longer apply to head: based on %x but head is %x", parent, b.a.Head()) - } - } - return b.out, nil -} - -// NewUpdater returns a builder you can use to make changes to -// the tailnet key authority. -// -// The provided signer function, if non-nil, is called with each update -// to compute and apply signatures. -// -// Updates are specified by calling methods on the returned UpdatedBuilder. -// Call Finalize() when you are done to obtain the specific update messages -// which actuate the changes. -func (a *Authority) NewUpdater(signer Signer) *UpdateBuilder { - return &UpdateBuilder{ - a: a, - signer: signer, - parent: a.Head(), - state: a.state, - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "fmt" + "os" + + "tailscale.com/types/tkatype" +) + +// Types implementing Signer can sign update messages. +type Signer interface { + // SignAUM returns signatures for the AUM encoded by the given AUMSigHash. + SignAUM(tkatype.AUMSigHash) ([]tkatype.Signature, error) +} + +// UpdateBuilder implements a builder for changes to the tailnet +// key authority. +// +// Finalize must be called to compute the update messages, which +// must then be applied to all Authority objects using Inform(). +type UpdateBuilder struct { + a *Authority + signer Signer + + state State + parent AUMHash + + out []AUM +} + +func (b *UpdateBuilder) mkUpdate(update AUM) error { + prevHash := make([]byte, len(b.parent)) + copy(prevHash, b.parent[:]) + update.PrevAUMHash = prevHash + + if b.signer != nil { + sigs, err := b.signer.SignAUM(update.SigHash()) + if err != nil { + return fmt.Errorf("signing failed: %v", err) + } + update.Signatures = append(update.Signatures, sigs...) + } + if err := update.StaticValidate(); err != nil { + return fmt.Errorf("generated update was invalid: %v", err) + } + state, err := b.state.applyVerifiedAUM(update) + if err != nil { + return fmt.Errorf("update cannot be applied: %v", err) + } + + b.state = state + b.parent = update.Hash() + b.out = append(b.out, update) + return nil +} + +// AddKey adds a new key to the authority. +func (b *UpdateBuilder) AddKey(key Key) error { + keyID, err := key.ID() + if err != nil { + return err + } + + if _, err := b.state.GetKey(keyID); err == nil { + return fmt.Errorf("cannot add key %v: already exists", key) + } + return b.mkUpdate(AUM{MessageKind: AUMAddKey, Key: &key}) +} + +// RemoveKey removes a key from the authority. +func (b *UpdateBuilder) RemoveKey(keyID tkatype.KeyID) error { + if _, err := b.state.GetKey(keyID); err != nil { + return fmt.Errorf("failed reading key %x: %v", keyID, err) + } + return b.mkUpdate(AUM{MessageKind: AUMRemoveKey, KeyID: keyID}) +} + +// SetKeyVote updates the number of votes of an existing key. +func (b *UpdateBuilder) SetKeyVote(keyID tkatype.KeyID, votes uint) error { + if _, err := b.state.GetKey(keyID); err != nil { + return fmt.Errorf("failed reading key %x: %v", keyID, err) + } + return b.mkUpdate(AUM{MessageKind: AUMUpdateKey, Votes: &votes, KeyID: keyID}) +} + +// SetKeyMeta updates key-value metadata stored against an existing key. +// +// TODO(tom): Provide an API to update specific values rather than the whole +// map. +func (b *UpdateBuilder) SetKeyMeta(keyID tkatype.KeyID, meta map[string]string) error { + if _, err := b.state.GetKey(keyID); err != nil { + return fmt.Errorf("failed reading key %x: %v", keyID, err) + } + return b.mkUpdate(AUM{MessageKind: AUMUpdateKey, Meta: meta, KeyID: keyID}) +} + +func (b *UpdateBuilder) generateCheckpoint() error { + // Compute the checkpoint state. + state := b.a.state + for i, update := range b.out { + var err error + if state, err = state.applyVerifiedAUM(update); err != nil { + return fmt.Errorf("applying update %d: %v", i, err) + } + } + + // Checkpoints cant specify a parent AUM. + state.LastAUMHash = nil + return b.mkUpdate(AUM{MessageKind: AUMCheckpoint, State: &state}) +} + +// checkpointEvery sets how often a checkpoint AUM should be generated. +const checkpointEvery = 50 + +// Finalize returns the set of update message to actuate the update. +func (b *UpdateBuilder) Finalize(storage Chonk) ([]AUM, error) { + var ( + needCheckpoint bool = true + cursor AUMHash = b.a.Head() + ) + for i := len(b.out); i < checkpointEvery; i++ { + aum, err := storage.AUM(cursor) + if err != nil { + if err == os.ErrNotExist { + // The available chain is shorter than the interval to checkpoint at. + needCheckpoint = false + break + } + return nil, fmt.Errorf("reading AUM: %v", err) + } + + if aum.MessageKind == AUMCheckpoint { + needCheckpoint = false + break + } + + parent, hasParent := aum.Parent() + if !hasParent { + // We've hit the genesis update, so the chain is shorter than the interval to checkpoint at. + needCheckpoint = false + break + } + cursor = parent + } + + if needCheckpoint { + if err := b.generateCheckpoint(); err != nil { + return nil, fmt.Errorf("generating checkpoint: %v", err) + } + } + + // Check no AUMs were applied in the meantime + if len(b.out) > 0 { + if parent, _ := b.out[0].Parent(); parent != b.a.Head() { + return nil, fmt.Errorf("updates no longer apply to head: based on %x but head is %x", parent, b.a.Head()) + } + } + return b.out, nil +} + +// NewUpdater returns a builder you can use to make changes to +// the tailnet key authority. +// +// The provided signer function, if non-nil, is called with each update +// to compute and apply signatures. +// +// Updates are specified by calling methods on the returned UpdatedBuilder. +// Call Finalize() when you are done to obtain the specific update messages +// which actuate the changes. +func (a *Authority) NewUpdater(signer Signer) *UpdateBuilder { + return &UpdateBuilder{ + a: a, + signer: signer, + parent: a.Head(), + state: a.state, + } +} diff --git a/tka/builder_test.go b/tka/builder_test.go index 758fb170c..666af9ad0 100644 --- a/tka/builder_test.go +++ b/tka/builder_test.go @@ -1,270 +1,270 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "crypto/ed25519" - "testing" - - "github.com/google/go-cmp/cmp" - "tailscale.com/types/tkatype" -) - -type signer25519 ed25519.PrivateKey - -func (s signer25519) SignAUM(sigHash tkatype.AUMSigHash) ([]tkatype.Signature, error) { - priv := ed25519.PrivateKey(s) - key := Key{Kind: Key25519, Public: priv.Public().(ed25519.PublicKey)} - - return []tkatype.Signature{{ - KeyID: key.MustID(), - Signature: ed25519.Sign(priv, sigHash[:]), - }}, nil -} - -func TestAuthorityBuilderAddKey(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - storage := &Mem{} - a, _, err := Create(storage, State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, signer25519(priv)) - if err != nil { - t.Fatalf("Create() failed: %v", err) - } - - pub2, _ := testingKey25519(t, 2) - key2 := Key{Kind: Key25519, Public: pub2, Votes: 1} - - b := a.NewUpdater(signer25519(priv)) - if err := b.AddKey(key2); err != nil { - t.Fatalf("AddKey(%v) failed: %v", key2, err) - } - updates, err := b.Finalize(storage) - if err != nil { - t.Fatalf("Finalize() failed: %v", err) - } - - // See if the update is valid by applying it to the authority - // + checking if the new key is there. - if err := a.Inform(storage, updates); err != nil { - t.Fatalf("could not apply generated updates: %v", err) - } - if _, err := a.state.GetKey(key2.MustID()); err != nil { - t.Errorf("could not read new key: %v", err) - } -} - -func TestAuthorityBuilderRemoveKey(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - pub2, _ := testingKey25519(t, 2) - key2 := Key{Kind: Key25519, Public: pub2, Votes: 1} - - storage := &Mem{} - a, _, err := Create(storage, State{ - Keys: []Key{key, key2}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, signer25519(priv)) - if err != nil { - t.Fatalf("Create() failed: %v", err) - } - - b := a.NewUpdater(signer25519(priv)) - if err := b.RemoveKey(key2.MustID()); err != nil { - t.Fatalf("RemoveKey(%v) failed: %v", key2, err) - } - updates, err := b.Finalize(storage) - if err != nil { - t.Fatalf("Finalize() failed: %v", err) - } - - // See if the update is valid by applying it to the authority - // + checking if the key has been removed. - if err := a.Inform(storage, updates); err != nil { - t.Fatalf("could not apply generated updates: %v", err) - } - if _, err := a.state.GetKey(key2.MustID()); err != ErrNoSuchKey { - t.Errorf("GetKey(key2).err = %v, want %v", err, ErrNoSuchKey) - } -} - -func TestAuthorityBuilderSetKeyVote(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - storage := &Mem{} - a, _, err := Create(storage, State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, signer25519(priv)) - if err != nil { - t.Fatalf("Create() failed: %v", err) - } - - b := a.NewUpdater(signer25519(priv)) - if err := b.SetKeyVote(key.MustID(), 5); err != nil { - t.Fatalf("SetKeyVote(%v) failed: %v", key.MustID(), err) - } - updates, err := b.Finalize(storage) - if err != nil { - t.Fatalf("Finalize() failed: %v", err) - } - - // See if the update is valid by applying it to the authority - // + checking if the update is there. - if err := a.Inform(storage, updates); err != nil { - t.Fatalf("could not apply generated updates: %v", err) - } - k, err := a.state.GetKey(key.MustID()) - if err != nil { - t.Fatal(err) - } - if got, want := k.Votes, uint(5); got != want { - t.Errorf("key.Votes = %d, want %d", got, want) - } -} - -func TestAuthorityBuilderSetKeyMeta(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2, Meta: map[string]string{"a": "b"}} - - storage := &Mem{} - a, _, err := Create(storage, State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, signer25519(priv)) - if err != nil { - t.Fatalf("Create() failed: %v", err) - } - - b := a.NewUpdater(signer25519(priv)) - if err := b.SetKeyMeta(key.MustID(), map[string]string{"b": "c"}); err != nil { - t.Fatalf("SetKeyMeta(%v) failed: %v", key, err) - } - updates, err := b.Finalize(storage) - if err != nil { - t.Fatalf("Finalize() failed: %v", err) - } - - // See if the update is valid by applying it to the authority - // + checking if the update is there. - if err := a.Inform(storage, updates); err != nil { - t.Fatalf("could not apply generated updates: %v", err) - } - k, err := a.state.GetKey(key.MustID()) - if err != nil { - t.Fatal(err) - } - if diff := cmp.Diff(map[string]string{"b": "c"}, k.Meta); diff != "" { - t.Errorf("updated meta differs (-want, +got):\n%s", diff) - } -} - -func TestAuthorityBuilderMultiple(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - storage := &Mem{} - a, _, err := Create(storage, State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, signer25519(priv)) - if err != nil { - t.Fatalf("Create() failed: %v", err) - } - - pub2, _ := testingKey25519(t, 2) - key2 := Key{Kind: Key25519, Public: pub2, Votes: 1} - - b := a.NewUpdater(signer25519(priv)) - if err := b.AddKey(key2); err != nil { - t.Fatalf("AddKey(%v) failed: %v", key2, err) - } - if err := b.SetKeyVote(key2.MustID(), 42); err != nil { - t.Fatalf("SetKeyVote(%v) failed: %v", key2, err) - } - if err := b.RemoveKey(key.MustID()); err != nil { - t.Fatalf("RemoveKey(%v) failed: %v", key, err) - } - updates, err := b.Finalize(storage) - if err != nil { - t.Fatalf("Finalize() failed: %v", err) - } - - // See if the update is valid by applying it to the authority - // + checking if the update is there. - if err := a.Inform(storage, updates); err != nil { - t.Fatalf("could not apply generated updates: %v", err) - } - k, err := a.state.GetKey(key2.MustID()) - if err != nil { - t.Fatal(err) - } - if got, want := k.Votes, uint(42); got != want { - t.Errorf("key.Votes = %d, want %d", got, want) - } - if _, err := a.state.GetKey(key.MustID()); err != ErrNoSuchKey { - t.Errorf("GetKey(key).err = %v, want %v", err, ErrNoSuchKey) - } -} - -func TestAuthorityBuilderCheckpointsAfterXUpdates(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - storage := &Mem{} - a, _, err := Create(storage, State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, signer25519(priv)) - if err != nil { - t.Fatalf("Create() failed: %v", err) - } - - for i := 0; i <= checkpointEvery; i++ { - pub2, _ := testingKey25519(t, int64(i+2)) - key2 := Key{Kind: Key25519, Public: pub2, Votes: 1} - - b := a.NewUpdater(signer25519(priv)) - if err := b.AddKey(key2); err != nil { - t.Fatalf("AddKey(%v) failed: %v", key2, err) - } - updates, err := b.Finalize(storage) - if err != nil { - t.Fatalf("Finalize() failed: %v", err) - } - // See if the update is valid by applying it to the authority - // + checking if the new key is there. - if err := a.Inform(storage, updates); err != nil { - t.Fatalf("could not apply generated updates: %v", err) - } - if _, err := a.state.GetKey(key2.MustID()); err != nil { - t.Fatal(err) - } - - wantKind := AUMAddKey - if i == checkpointEvery-1 { // Genesis + 49 updates == 50 (the value of checkpointEvery) - wantKind = AUMCheckpoint - } - lastAUM, err := storage.AUM(a.Head()) - if err != nil { - t.Fatal(err) - } - if lastAUM.MessageKind != wantKind { - t.Errorf("[%d] HeadAUM.MessageKind = %v, want %v", i, lastAUM.MessageKind, wantKind) - } - } - - // Try starting an authority just based on storage. - a2, err := Open(storage) - if err != nil { - t.Fatalf("Failed to open from stored AUMs: %v", err) - } - if a.Head() != a2.Head() { - t.Errorf("stored and computed HEAD differ: got %v, want %v", a2.Head(), a.Head()) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "crypto/ed25519" + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/types/tkatype" +) + +type signer25519 ed25519.PrivateKey + +func (s signer25519) SignAUM(sigHash tkatype.AUMSigHash) ([]tkatype.Signature, error) { + priv := ed25519.PrivateKey(s) + key := Key{Kind: Key25519, Public: priv.Public().(ed25519.PublicKey)} + + return []tkatype.Signature{{ + KeyID: key.MustID(), + Signature: ed25519.Sign(priv, sigHash[:]), + }}, nil +} + +func TestAuthorityBuilderAddKey(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + storage := &Mem{} + a, _, err := Create(storage, State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }, signer25519(priv)) + if err != nil { + t.Fatalf("Create() failed: %v", err) + } + + pub2, _ := testingKey25519(t, 2) + key2 := Key{Kind: Key25519, Public: pub2, Votes: 1} + + b := a.NewUpdater(signer25519(priv)) + if err := b.AddKey(key2); err != nil { + t.Fatalf("AddKey(%v) failed: %v", key2, err) + } + updates, err := b.Finalize(storage) + if err != nil { + t.Fatalf("Finalize() failed: %v", err) + } + + // See if the update is valid by applying it to the authority + // + checking if the new key is there. + if err := a.Inform(storage, updates); err != nil { + t.Fatalf("could not apply generated updates: %v", err) + } + if _, err := a.state.GetKey(key2.MustID()); err != nil { + t.Errorf("could not read new key: %v", err) + } +} + +func TestAuthorityBuilderRemoveKey(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + pub2, _ := testingKey25519(t, 2) + key2 := Key{Kind: Key25519, Public: pub2, Votes: 1} + + storage := &Mem{} + a, _, err := Create(storage, State{ + Keys: []Key{key, key2}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }, signer25519(priv)) + if err != nil { + t.Fatalf("Create() failed: %v", err) + } + + b := a.NewUpdater(signer25519(priv)) + if err := b.RemoveKey(key2.MustID()); err != nil { + t.Fatalf("RemoveKey(%v) failed: %v", key2, err) + } + updates, err := b.Finalize(storage) + if err != nil { + t.Fatalf("Finalize() failed: %v", err) + } + + // See if the update is valid by applying it to the authority + // + checking if the key has been removed. + if err := a.Inform(storage, updates); err != nil { + t.Fatalf("could not apply generated updates: %v", err) + } + if _, err := a.state.GetKey(key2.MustID()); err != ErrNoSuchKey { + t.Errorf("GetKey(key2).err = %v, want %v", err, ErrNoSuchKey) + } +} + +func TestAuthorityBuilderSetKeyVote(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + storage := &Mem{} + a, _, err := Create(storage, State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }, signer25519(priv)) + if err != nil { + t.Fatalf("Create() failed: %v", err) + } + + b := a.NewUpdater(signer25519(priv)) + if err := b.SetKeyVote(key.MustID(), 5); err != nil { + t.Fatalf("SetKeyVote(%v) failed: %v", key.MustID(), err) + } + updates, err := b.Finalize(storage) + if err != nil { + t.Fatalf("Finalize() failed: %v", err) + } + + // See if the update is valid by applying it to the authority + // + checking if the update is there. + if err := a.Inform(storage, updates); err != nil { + t.Fatalf("could not apply generated updates: %v", err) + } + k, err := a.state.GetKey(key.MustID()) + if err != nil { + t.Fatal(err) + } + if got, want := k.Votes, uint(5); got != want { + t.Errorf("key.Votes = %d, want %d", got, want) + } +} + +func TestAuthorityBuilderSetKeyMeta(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2, Meta: map[string]string{"a": "b"}} + + storage := &Mem{} + a, _, err := Create(storage, State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }, signer25519(priv)) + if err != nil { + t.Fatalf("Create() failed: %v", err) + } + + b := a.NewUpdater(signer25519(priv)) + if err := b.SetKeyMeta(key.MustID(), map[string]string{"b": "c"}); err != nil { + t.Fatalf("SetKeyMeta(%v) failed: %v", key, err) + } + updates, err := b.Finalize(storage) + if err != nil { + t.Fatalf("Finalize() failed: %v", err) + } + + // See if the update is valid by applying it to the authority + // + checking if the update is there. + if err := a.Inform(storage, updates); err != nil { + t.Fatalf("could not apply generated updates: %v", err) + } + k, err := a.state.GetKey(key.MustID()) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(map[string]string{"b": "c"}, k.Meta); diff != "" { + t.Errorf("updated meta differs (-want, +got):\n%s", diff) + } +} + +func TestAuthorityBuilderMultiple(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + storage := &Mem{} + a, _, err := Create(storage, State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }, signer25519(priv)) + if err != nil { + t.Fatalf("Create() failed: %v", err) + } + + pub2, _ := testingKey25519(t, 2) + key2 := Key{Kind: Key25519, Public: pub2, Votes: 1} + + b := a.NewUpdater(signer25519(priv)) + if err := b.AddKey(key2); err != nil { + t.Fatalf("AddKey(%v) failed: %v", key2, err) + } + if err := b.SetKeyVote(key2.MustID(), 42); err != nil { + t.Fatalf("SetKeyVote(%v) failed: %v", key2, err) + } + if err := b.RemoveKey(key.MustID()); err != nil { + t.Fatalf("RemoveKey(%v) failed: %v", key, err) + } + updates, err := b.Finalize(storage) + if err != nil { + t.Fatalf("Finalize() failed: %v", err) + } + + // See if the update is valid by applying it to the authority + // + checking if the update is there. + if err := a.Inform(storage, updates); err != nil { + t.Fatalf("could not apply generated updates: %v", err) + } + k, err := a.state.GetKey(key2.MustID()) + if err != nil { + t.Fatal(err) + } + if got, want := k.Votes, uint(42); got != want { + t.Errorf("key.Votes = %d, want %d", got, want) + } + if _, err := a.state.GetKey(key.MustID()); err != ErrNoSuchKey { + t.Errorf("GetKey(key).err = %v, want %v", err, ErrNoSuchKey) + } +} + +func TestAuthorityBuilderCheckpointsAfterXUpdates(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + storage := &Mem{} + a, _, err := Create(storage, State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }, signer25519(priv)) + if err != nil { + t.Fatalf("Create() failed: %v", err) + } + + for i := 0; i <= checkpointEvery; i++ { + pub2, _ := testingKey25519(t, int64(i+2)) + key2 := Key{Kind: Key25519, Public: pub2, Votes: 1} + + b := a.NewUpdater(signer25519(priv)) + if err := b.AddKey(key2); err != nil { + t.Fatalf("AddKey(%v) failed: %v", key2, err) + } + updates, err := b.Finalize(storage) + if err != nil { + t.Fatalf("Finalize() failed: %v", err) + } + // See if the update is valid by applying it to the authority + // + checking if the new key is there. + if err := a.Inform(storage, updates); err != nil { + t.Fatalf("could not apply generated updates: %v", err) + } + if _, err := a.state.GetKey(key2.MustID()); err != nil { + t.Fatal(err) + } + + wantKind := AUMAddKey + if i == checkpointEvery-1 { // Genesis + 49 updates == 50 (the value of checkpointEvery) + wantKind = AUMCheckpoint + } + lastAUM, err := storage.AUM(a.Head()) + if err != nil { + t.Fatal(err) + } + if lastAUM.MessageKind != wantKind { + t.Errorf("[%d] HeadAUM.MessageKind = %v, want %v", i, lastAUM.MessageKind, wantKind) + } + } + + // Try starting an authority just based on storage. + a2, err := Open(storage) + if err != nil { + t.Fatalf("Failed to open from stored AUMs: %v", err) + } + if a.Head() != a2.Head() { + t.Errorf("stored and computed HEAD differ: got %v, want %v", a2.Head(), a.Head()) + } +} diff --git a/tka/deeplink.go b/tka/deeplink.go index 97bcd664b..5cf24fc5c 100644 --- a/tka/deeplink.go +++ b/tka/deeplink.go @@ -1,221 +1,221 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "crypto/hmac" - "crypto/sha256" - "encoding/binary" - "encoding/hex" - "fmt" - "net/url" - "strings" -) - -const ( - DeeplinkTailscaleURLScheme = "tailscale" - DeeplinkCommandSign = "sign-device" -) - -// generateHMAC computes a SHA-256 HMAC for the concatenation of components, -// using the Authority stateID as secret. -func (a *Authority) generateHMAC(params NewDeeplinkParams) []byte { - stateID, _ := a.StateIDs() - - key := make([]byte, 8) - binary.LittleEndian.PutUint64(key, stateID) - mac := hmac.New(sha256.New, key) - mac.Write([]byte(params.NodeKey)) - mac.Write([]byte(params.TLPub)) - mac.Write([]byte(params.DeviceName)) - mac.Write([]byte(params.OSName)) - mac.Write([]byte(params.LoginName)) - return mac.Sum(nil) -} - -type NewDeeplinkParams struct { - NodeKey string - TLPub string - DeviceName string - OSName string - LoginName string -} - -// NewDeeplink creates a signed deeplink using the authority's stateID as a -// secret. This deeplink can then be validated by ValidateDeeplink. -func (a *Authority) NewDeeplink(params NewDeeplinkParams) (string, error) { - if params.NodeKey == "" || !strings.HasPrefix(params.NodeKey, "nodekey:") { - return "", fmt.Errorf("invalid node key %q", params.NodeKey) - } - if params.TLPub == "" || !strings.HasPrefix(params.TLPub, "tlpub:") { - return "", fmt.Errorf("invalid tlpub %q", params.TLPub) - } - if params.DeviceName == "" { - return "", fmt.Errorf("invalid device name %q", params.DeviceName) - } - if params.OSName == "" { - return "", fmt.Errorf("invalid os name %q", params.OSName) - } - if params.LoginName == "" { - return "", fmt.Errorf("invalid login name %q", params.LoginName) - } - - u := url.URL{ - Scheme: DeeplinkTailscaleURLScheme, - Host: DeeplinkCommandSign, - Path: "/v1/", - } - v := url.Values{} - v.Set("nk", params.NodeKey) - v.Set("tp", params.TLPub) - v.Set("dn", params.DeviceName) - v.Set("os", params.OSName) - v.Set("em", params.LoginName) - - hmac := a.generateHMAC(params) - v.Set("hm", hex.EncodeToString(hmac)) - - u.RawQuery = v.Encode() - return u.String(), nil -} - -type DeeplinkValidationResult struct { - IsValid bool - Error string - Version uint8 - NodeKey string - TLPub string - DeviceName string - OSName string - EmailAddress string -} - -// ValidateDeeplink validates a device signing deeplink using the authority's stateID. -// The input urlString follows this structure: -// -// tailscale://sign-device/v1/?nk=xxx&tp=xxx&dn=xxx&os=xxx&em=xxx&hm=xxx -// -// where: -// - "nk" is the nodekey of the node being signed -// - "tp" is the tailnet lock public key -// - "dn" is the name of the node -// - "os" is the operating system of the node -// - "em" is the email address associated with the node -// - "hm" is a SHA-256 HMAC computed over the concatenation of the above fields, encoded as a hex string -func (a *Authority) ValidateDeeplink(urlString string) DeeplinkValidationResult { - parsedUrl, err := url.Parse(urlString) - if err != nil { - return DeeplinkValidationResult{ - IsValid: false, - Error: err.Error(), - } - } - - if parsedUrl.Scheme != DeeplinkTailscaleURLScheme { - return DeeplinkValidationResult{ - IsValid: false, - Error: fmt.Sprintf("unhandled scheme %s, expected %s", parsedUrl.Scheme, DeeplinkTailscaleURLScheme), - } - } - - if parsedUrl.Host != DeeplinkCommandSign { - return DeeplinkValidationResult{ - IsValid: false, - Error: fmt.Sprintf("unhandled host %s, expected %s", parsedUrl.Host, DeeplinkCommandSign), - } - } - - path := parsedUrl.EscapedPath() - pathComponents := strings.Split(path, "/") - if len(pathComponents) != 3 { - return DeeplinkValidationResult{ - IsValid: false, - Error: "invalid path components number found", - } - } - - if pathComponents[1] != "v1" { - return DeeplinkValidationResult{ - IsValid: false, - Error: fmt.Sprintf("expected v1 deeplink version, found something else: %s", pathComponents[1]), - } - } - - nodeKey := parsedUrl.Query().Get("nk") - if len(nodeKey) == 0 { - return DeeplinkValidationResult{ - IsValid: false, - Error: "missing nk (NodeKey) query parameter", - } - } - - tlPub := parsedUrl.Query().Get("tp") - if len(tlPub) == 0 { - return DeeplinkValidationResult{ - IsValid: false, - Error: "missing tp (TLPub) query parameter", - } - } - - deviceName := parsedUrl.Query().Get("dn") - if len(deviceName) == 0 { - return DeeplinkValidationResult{ - IsValid: false, - Error: "missing dn (DeviceName) query parameter", - } - } - - osName := parsedUrl.Query().Get("os") - if len(deviceName) == 0 { - return DeeplinkValidationResult{ - IsValid: false, - Error: "missing os (OSName) query parameter", - } - } - - emailAddress := parsedUrl.Query().Get("em") - if len(emailAddress) == 0 { - return DeeplinkValidationResult{ - IsValid: false, - Error: "missing em (EmailAddress) query parameter", - } - } - - hmacString := parsedUrl.Query().Get("hm") - if len(hmacString) == 0 { - return DeeplinkValidationResult{ - IsValid: false, - Error: "missing hm (HMAC) query parameter", - } - } - - computedHMAC := a.generateHMAC(NewDeeplinkParams{ - NodeKey: nodeKey, - TLPub: tlPub, - DeviceName: deviceName, - OSName: osName, - LoginName: emailAddress, - }) - - hmacHexBytes, err := hex.DecodeString(hmacString) - if err != nil { - return DeeplinkValidationResult{IsValid: false, Error: "could not hex-decode hmac"} - } - - if !hmac.Equal(computedHMAC, hmacHexBytes) { - return DeeplinkValidationResult{ - IsValid: false, - Error: "hmac authentication failed", - } - } - - return DeeplinkValidationResult{ - IsValid: true, - NodeKey: nodeKey, - TLPub: tlPub, - DeviceName: deviceName, - OSName: osName, - EmailAddress: emailAddress, - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "fmt" + "net/url" + "strings" +) + +const ( + DeeplinkTailscaleURLScheme = "tailscale" + DeeplinkCommandSign = "sign-device" +) + +// generateHMAC computes a SHA-256 HMAC for the concatenation of components, +// using the Authority stateID as secret. +func (a *Authority) generateHMAC(params NewDeeplinkParams) []byte { + stateID, _ := a.StateIDs() + + key := make([]byte, 8) + binary.LittleEndian.PutUint64(key, stateID) + mac := hmac.New(sha256.New, key) + mac.Write([]byte(params.NodeKey)) + mac.Write([]byte(params.TLPub)) + mac.Write([]byte(params.DeviceName)) + mac.Write([]byte(params.OSName)) + mac.Write([]byte(params.LoginName)) + return mac.Sum(nil) +} + +type NewDeeplinkParams struct { + NodeKey string + TLPub string + DeviceName string + OSName string + LoginName string +} + +// NewDeeplink creates a signed deeplink using the authority's stateID as a +// secret. This deeplink can then be validated by ValidateDeeplink. +func (a *Authority) NewDeeplink(params NewDeeplinkParams) (string, error) { + if params.NodeKey == "" || !strings.HasPrefix(params.NodeKey, "nodekey:") { + return "", fmt.Errorf("invalid node key %q", params.NodeKey) + } + if params.TLPub == "" || !strings.HasPrefix(params.TLPub, "tlpub:") { + return "", fmt.Errorf("invalid tlpub %q", params.TLPub) + } + if params.DeviceName == "" { + return "", fmt.Errorf("invalid device name %q", params.DeviceName) + } + if params.OSName == "" { + return "", fmt.Errorf("invalid os name %q", params.OSName) + } + if params.LoginName == "" { + return "", fmt.Errorf("invalid login name %q", params.LoginName) + } + + u := url.URL{ + Scheme: DeeplinkTailscaleURLScheme, + Host: DeeplinkCommandSign, + Path: "/v1/", + } + v := url.Values{} + v.Set("nk", params.NodeKey) + v.Set("tp", params.TLPub) + v.Set("dn", params.DeviceName) + v.Set("os", params.OSName) + v.Set("em", params.LoginName) + + hmac := a.generateHMAC(params) + v.Set("hm", hex.EncodeToString(hmac)) + + u.RawQuery = v.Encode() + return u.String(), nil +} + +type DeeplinkValidationResult struct { + IsValid bool + Error string + Version uint8 + NodeKey string + TLPub string + DeviceName string + OSName string + EmailAddress string +} + +// ValidateDeeplink validates a device signing deeplink using the authority's stateID. +// The input urlString follows this structure: +// +// tailscale://sign-device/v1/?nk=xxx&tp=xxx&dn=xxx&os=xxx&em=xxx&hm=xxx +// +// where: +// - "nk" is the nodekey of the node being signed +// - "tp" is the tailnet lock public key +// - "dn" is the name of the node +// - "os" is the operating system of the node +// - "em" is the email address associated with the node +// - "hm" is a SHA-256 HMAC computed over the concatenation of the above fields, encoded as a hex string +func (a *Authority) ValidateDeeplink(urlString string) DeeplinkValidationResult { + parsedUrl, err := url.Parse(urlString) + if err != nil { + return DeeplinkValidationResult{ + IsValid: false, + Error: err.Error(), + } + } + + if parsedUrl.Scheme != DeeplinkTailscaleURLScheme { + return DeeplinkValidationResult{ + IsValid: false, + Error: fmt.Sprintf("unhandled scheme %s, expected %s", parsedUrl.Scheme, DeeplinkTailscaleURLScheme), + } + } + + if parsedUrl.Host != DeeplinkCommandSign { + return DeeplinkValidationResult{ + IsValid: false, + Error: fmt.Sprintf("unhandled host %s, expected %s", parsedUrl.Host, DeeplinkCommandSign), + } + } + + path := parsedUrl.EscapedPath() + pathComponents := strings.Split(path, "/") + if len(pathComponents) != 3 { + return DeeplinkValidationResult{ + IsValid: false, + Error: "invalid path components number found", + } + } + + if pathComponents[1] != "v1" { + return DeeplinkValidationResult{ + IsValid: false, + Error: fmt.Sprintf("expected v1 deeplink version, found something else: %s", pathComponents[1]), + } + } + + nodeKey := parsedUrl.Query().Get("nk") + if len(nodeKey) == 0 { + return DeeplinkValidationResult{ + IsValid: false, + Error: "missing nk (NodeKey) query parameter", + } + } + + tlPub := parsedUrl.Query().Get("tp") + if len(tlPub) == 0 { + return DeeplinkValidationResult{ + IsValid: false, + Error: "missing tp (TLPub) query parameter", + } + } + + deviceName := parsedUrl.Query().Get("dn") + if len(deviceName) == 0 { + return DeeplinkValidationResult{ + IsValid: false, + Error: "missing dn (DeviceName) query parameter", + } + } + + osName := parsedUrl.Query().Get("os") + if len(deviceName) == 0 { + return DeeplinkValidationResult{ + IsValid: false, + Error: "missing os (OSName) query parameter", + } + } + + emailAddress := parsedUrl.Query().Get("em") + if len(emailAddress) == 0 { + return DeeplinkValidationResult{ + IsValid: false, + Error: "missing em (EmailAddress) query parameter", + } + } + + hmacString := parsedUrl.Query().Get("hm") + if len(hmacString) == 0 { + return DeeplinkValidationResult{ + IsValid: false, + Error: "missing hm (HMAC) query parameter", + } + } + + computedHMAC := a.generateHMAC(NewDeeplinkParams{ + NodeKey: nodeKey, + TLPub: tlPub, + DeviceName: deviceName, + OSName: osName, + LoginName: emailAddress, + }) + + hmacHexBytes, err := hex.DecodeString(hmacString) + if err != nil { + return DeeplinkValidationResult{IsValid: false, Error: "could not hex-decode hmac"} + } + + if !hmac.Equal(computedHMAC, hmacHexBytes) { + return DeeplinkValidationResult{ + IsValid: false, + Error: "hmac authentication failed", + } + } + + return DeeplinkValidationResult{ + IsValid: true, + NodeKey: nodeKey, + TLPub: tlPub, + DeviceName: deviceName, + OSName: osName, + EmailAddress: emailAddress, + } +} diff --git a/tka/deeplink_test.go b/tka/deeplink_test.go index 397cc6917..03523202f 100644 --- a/tka/deeplink_test.go +++ b/tka/deeplink_test.go @@ -1,52 +1,52 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "testing" -) - -func TestGenerateDeeplink(t *testing.T) { - pub, _ := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - c := newTestchain(t, ` - G1 -> L1 - - G1.template = genesis - `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), - ) - a, _ := Open(c.Chonk()) - - nodeKey := "nodekey:1234567890" - tlPub := "tlpub:1234567890" - deviceName := "Example Device" - osName := "iOS" - loginName := "insecure@example.com" - - deeplink, err := a.NewDeeplink(NewDeeplinkParams{ - NodeKey: nodeKey, - TLPub: tlPub, - DeviceName: deviceName, - OSName: osName, - LoginName: loginName, - }) - if err != nil { - t.Errorf("deeplink generation failed: %v", err) - } - - res := a.ValidateDeeplink(deeplink) - if !res.IsValid { - t.Errorf("deeplink validation failed: %s", res.Error) - } - if res.NodeKey != nodeKey { - t.Errorf("node key mismatch: %s != %s", res.NodeKey, nodeKey) - } - if res.TLPub != tlPub { - t.Errorf("tlpub mismatch: %s != %s", res.TLPub, tlPub) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "testing" +) + +func TestGenerateDeeplink(t *testing.T) { + pub, _ := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + c := newTestchain(t, ` + G1 -> L1 + + G1.template = genesis + `, + optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }}), + ) + a, _ := Open(c.Chonk()) + + nodeKey := "nodekey:1234567890" + tlPub := "tlpub:1234567890" + deviceName := "Example Device" + osName := "iOS" + loginName := "insecure@example.com" + + deeplink, err := a.NewDeeplink(NewDeeplinkParams{ + NodeKey: nodeKey, + TLPub: tlPub, + DeviceName: deviceName, + OSName: osName, + LoginName: loginName, + }) + if err != nil { + t.Errorf("deeplink generation failed: %v", err) + } + + res := a.ValidateDeeplink(deeplink) + if !res.IsValid { + t.Errorf("deeplink validation failed: %s", res.Error) + } + if res.NodeKey != nodeKey { + t.Errorf("node key mismatch: %s != %s", res.NodeKey, nodeKey) + } + if res.TLPub != tlPub { + t.Errorf("tlpub mismatch: %s != %s", res.TLPub, tlPub) + } +} diff --git a/tka/key.go b/tka/key.go index 47218438d..07736795d 100644 --- a/tka/key.go +++ b/tka/key.go @@ -1,159 +1,159 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "crypto/ed25519" - "errors" - "fmt" - - "github.com/hdevalence/ed25519consensus" - "tailscale.com/types/tkatype" -) - -// KeyKind describes the different varieties of a Key. -type KeyKind uint8 - -// Valid KeyKind values. -const ( - KeyInvalid KeyKind = iota - Key25519 -) - -func (k KeyKind) String() string { - switch k { - case KeyInvalid: - return "invalid" - case Key25519: - return "25519" - default: - return fmt.Sprintf("Key?<%d>", int(k)) - } -} - -// Key describes the public components of a key known to network-lock. -type Key struct { - Kind KeyKind `cbor:"1,keyasint"` - - // Votes describes the weight applied to signatures using this key. - // Weighting is used to deterministically resolve branches in the AUM - // chain (i.e. forks, where two AUMs exist with the same parent). - Votes uint `cbor:"2,keyasint"` - - // Public encodes the public key of the key. For 25519 keys, - // this is simply the point on the curve representing the public - // key. - Public []byte `cbor:"3,keyasint"` - - // Meta describes arbitrary metadata about the key. This could be - // used to store the name of the key, for instance. - Meta map[string]string `cbor:"12,keyasint,omitempty"` -} - -// Clone makes an independent copy of Key. -// -// NOTE: There is a difference between a nil slice and an empty slice for encoding purposes, -// so an implementation of Clone() must take care to preserve this. -func (k Key) Clone() Key { - out := k - - if k.Public != nil { - out.Public = make([]byte, len(k.Public)) - copy(out.Public, k.Public) - } - - if k.Meta != nil { - out.Meta = make(map[string]string, len(k.Meta)) - for k, v := range k.Meta { - out.Meta[k] = v - } - } - - return out -} - -// MustID returns the KeyID of the key, panicking if an error is -// encountered. This must only be used for tests. -func (k Key) MustID() tkatype.KeyID { - id, err := k.ID() - if err != nil { - panic(err) - } - return id -} - -// ID returns the KeyID of the key. -func (k Key) ID() (tkatype.KeyID, error) { - switch k.Kind { - // Because 25519 public keys are so short, we just use the 32-byte - // public as their 'key ID'. - case Key25519: - return tkatype.KeyID(k.Public), nil - default: - return nil, fmt.Errorf("unknown key kind: %v", k.Kind) - } -} - -// Ed25519 returns the ed25519 public key encoded by Key. An error is -// returned for keys which do not represent ed25519 public keys. -func (k Key) Ed25519() (ed25519.PublicKey, error) { - switch k.Kind { - case Key25519: - return ed25519.PublicKey(k.Public), nil - default: - return nil, fmt.Errorf("key is of type %v, not ed25519", k.Kind) - } -} - -const maxMetaBytes = 512 - -func (k Key) StaticValidate() error { - if k.Votes > 4096 { - return fmt.Errorf("excessive key weight: %d > 4096", k.Votes) - } - if k.Votes == 0 { - return errors.New("key votes must be non-zero") - } - - // We have an arbitrary upper limit on the amount - // of metadata that can be associated with a key, so - // people don't start using it as a key-value store and - // causing pathological cases due to the number + size of - // AUMs. - var metaBytes uint - for k, v := range k.Meta { - metaBytes += uint(len(k) + len(v)) - } - if metaBytes > maxMetaBytes { - return fmt.Errorf("key metadata too big (%d > %d)", metaBytes, maxMetaBytes) - } - - switch k.Kind { - case Key25519: - default: - return fmt.Errorf("unrecognized key kind: %v", k.Kind) - } - return nil -} - -// Verify returns a nil error if the signature is valid over the -// provided AUM BLAKE2s digest, using the given key. -func signatureVerify(s *tkatype.Signature, aumDigest tkatype.AUMSigHash, key Key) error { - // NOTE(tom): Even if we can compute the public from the KeyID, - // its possible for the KeyID to be attacker-controlled - // so we should use the public contained in the state machine. - switch key.Kind { - case Key25519: - if len(key.Public) != ed25519.PublicKeySize { - return fmt.Errorf("ed25519 key has wrong length: %d", len(key.Public)) - } - if ed25519consensus.Verify(ed25519.PublicKey(key.Public), aumDigest[:], s.Signature) { - return nil - } - return errors.New("invalid signature") - - default: - return fmt.Errorf("unhandled key type: %v", key.Kind) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "crypto/ed25519" + "errors" + "fmt" + + "github.com/hdevalence/ed25519consensus" + "tailscale.com/types/tkatype" +) + +// KeyKind describes the different varieties of a Key. +type KeyKind uint8 + +// Valid KeyKind values. +const ( + KeyInvalid KeyKind = iota + Key25519 +) + +func (k KeyKind) String() string { + switch k { + case KeyInvalid: + return "invalid" + case Key25519: + return "25519" + default: + return fmt.Sprintf("Key?<%d>", int(k)) + } +} + +// Key describes the public components of a key known to network-lock. +type Key struct { + Kind KeyKind `cbor:"1,keyasint"` + + // Votes describes the weight applied to signatures using this key. + // Weighting is used to deterministically resolve branches in the AUM + // chain (i.e. forks, where two AUMs exist with the same parent). + Votes uint `cbor:"2,keyasint"` + + // Public encodes the public key of the key. For 25519 keys, + // this is simply the point on the curve representing the public + // key. + Public []byte `cbor:"3,keyasint"` + + // Meta describes arbitrary metadata about the key. This could be + // used to store the name of the key, for instance. + Meta map[string]string `cbor:"12,keyasint,omitempty"` +} + +// Clone makes an independent copy of Key. +// +// NOTE: There is a difference between a nil slice and an empty slice for encoding purposes, +// so an implementation of Clone() must take care to preserve this. +func (k Key) Clone() Key { + out := k + + if k.Public != nil { + out.Public = make([]byte, len(k.Public)) + copy(out.Public, k.Public) + } + + if k.Meta != nil { + out.Meta = make(map[string]string, len(k.Meta)) + for k, v := range k.Meta { + out.Meta[k] = v + } + } + + return out +} + +// MustID returns the KeyID of the key, panicking if an error is +// encountered. This must only be used for tests. +func (k Key) MustID() tkatype.KeyID { + id, err := k.ID() + if err != nil { + panic(err) + } + return id +} + +// ID returns the KeyID of the key. +func (k Key) ID() (tkatype.KeyID, error) { + switch k.Kind { + // Because 25519 public keys are so short, we just use the 32-byte + // public as their 'key ID'. + case Key25519: + return tkatype.KeyID(k.Public), nil + default: + return nil, fmt.Errorf("unknown key kind: %v", k.Kind) + } +} + +// Ed25519 returns the ed25519 public key encoded by Key. An error is +// returned for keys which do not represent ed25519 public keys. +func (k Key) Ed25519() (ed25519.PublicKey, error) { + switch k.Kind { + case Key25519: + return ed25519.PublicKey(k.Public), nil + default: + return nil, fmt.Errorf("key is of type %v, not ed25519", k.Kind) + } +} + +const maxMetaBytes = 512 + +func (k Key) StaticValidate() error { + if k.Votes > 4096 { + return fmt.Errorf("excessive key weight: %d > 4096", k.Votes) + } + if k.Votes == 0 { + return errors.New("key votes must be non-zero") + } + + // We have an arbitrary upper limit on the amount + // of metadata that can be associated with a key, so + // people don't start using it as a key-value store and + // causing pathological cases due to the number + size of + // AUMs. + var metaBytes uint + for k, v := range k.Meta { + metaBytes += uint(len(k) + len(v)) + } + if metaBytes > maxMetaBytes { + return fmt.Errorf("key metadata too big (%d > %d)", metaBytes, maxMetaBytes) + } + + switch k.Kind { + case Key25519: + default: + return fmt.Errorf("unrecognized key kind: %v", k.Kind) + } + return nil +} + +// Verify returns a nil error if the signature is valid over the +// provided AUM BLAKE2s digest, using the given key. +func signatureVerify(s *tkatype.Signature, aumDigest tkatype.AUMSigHash, key Key) error { + // NOTE(tom): Even if we can compute the public from the KeyID, + // its possible for the KeyID to be attacker-controlled + // so we should use the public contained in the state machine. + switch key.Kind { + case Key25519: + if len(key.Public) != ed25519.PublicKeySize { + return fmt.Errorf("ed25519 key has wrong length: %d", len(key.Public)) + } + if ed25519consensus.Verify(ed25519.PublicKey(key.Public), aumDigest[:], s.Signature) { + return nil + } + return errors.New("invalid signature") + + default: + return fmt.Errorf("unhandled key type: %v", key.Kind) + } +} diff --git a/tka/key_test.go b/tka/key_test.go index aaddb2f40..e912f89c4 100644 --- a/tka/key_test.go +++ b/tka/key_test.go @@ -1,97 +1,97 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "bytes" - "crypto/ed25519" - "encoding/binary" - "math/rand" - "testing" - - "tailscale.com/types/key" - "tailscale.com/types/tkatype" -) - -// returns a random source based on the test name + extraSeed. -func testingRand(t *testing.T, extraSeed int64) *rand.Rand { - var seed int64 - if err := binary.Read(bytes.NewBuffer([]byte(t.Name())), binary.LittleEndian, &seed); err != nil { - panic(err) - } - return rand.New(rand.NewSource(seed + extraSeed)) -} - -// generates a 25519 private key based on the seed + test name. -func testingKey25519(t *testing.T, seed int64) (ed25519.PublicKey, ed25519.PrivateKey) { - pub, priv, err := ed25519.GenerateKey(testingRand(t, seed)) - if err != nil { - panic(err) - } - return pub, priv -} - -func TestVerify25519(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{ - Kind: Key25519, - Public: pub, - } - - aum := AUM{ - MessageKind: AUMRemoveKey, - KeyID: []byte{1, 2, 3, 4}, - // Signatures is set to crap so we are sure its ignored in the sigHash computation. - Signatures: []tkatype.Signature{{KeyID: []byte{45, 42}}}, - } - sigHash := aum.SigHash() - aum.Signatures = []tkatype.Signature{ - { - KeyID: key.MustID(), - Signature: ed25519.Sign(priv, sigHash[:]), - }, - } - - if err := signatureVerify(&aum.Signatures[0], aum.SigHash(), key); err != nil { - t.Errorf("signature verification failed: %v", err) - } - - // Make sure it fails with a different public key. - pub2, _ := testingKey25519(t, 2) - key2 := Key{Kind: Key25519, Public: pub2} - if err := signatureVerify(&aum.Signatures[0], aum.SigHash(), key2); err == nil { - t.Error("signature verification with different key did not fail") - } -} - -func TestNLPrivate(t *testing.T) { - p := key.NewNLPrivate() - pub := p.Public() - - // Test that key.NLPrivate implements Signer by making a new - // authority. - k := Key{Kind: Key25519, Public: pub.Verifier(), Votes: 1} - _, aum, err := Create(&Mem{}, State{ - Keys: []Key{k}, - DisablementSecrets: [][]byte{bytes.Repeat([]byte{1}, 32)}, - }, p) - if err != nil { - t.Fatalf("Create() failed: %v", err) - } - - // Make sure the generated genesis AUM was signed. - if got, want := len(aum.Signatures), 1; got != want { - t.Fatalf("len(signatures) = %d, want %d", got, want) - } - sigHash := aum.SigHash() - if ok := ed25519.Verify(pub.Verifier(), sigHash[:], aum.Signatures[0].Signature); !ok { - t.Error("signature did not verify") - } - - // We manually compute the keyID, so make sure its consistent with - // tka.Key.ID(). - if !bytes.Equal(k.MustID(), p.KeyID()) { - t.Errorf("private.KeyID() & tka KeyID differ: %x != %x", k.MustID(), p.KeyID()) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "bytes" + "crypto/ed25519" + "encoding/binary" + "math/rand" + "testing" + + "tailscale.com/types/key" + "tailscale.com/types/tkatype" +) + +// returns a random source based on the test name + extraSeed. +func testingRand(t *testing.T, extraSeed int64) *rand.Rand { + var seed int64 + if err := binary.Read(bytes.NewBuffer([]byte(t.Name())), binary.LittleEndian, &seed); err != nil { + panic(err) + } + return rand.New(rand.NewSource(seed + extraSeed)) +} + +// generates a 25519 private key based on the seed + test name. +func testingKey25519(t *testing.T, seed int64) (ed25519.PublicKey, ed25519.PrivateKey) { + pub, priv, err := ed25519.GenerateKey(testingRand(t, seed)) + if err != nil { + panic(err) + } + return pub, priv +} + +func TestVerify25519(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{ + Kind: Key25519, + Public: pub, + } + + aum := AUM{ + MessageKind: AUMRemoveKey, + KeyID: []byte{1, 2, 3, 4}, + // Signatures is set to crap so we are sure its ignored in the sigHash computation. + Signatures: []tkatype.Signature{{KeyID: []byte{45, 42}}}, + } + sigHash := aum.SigHash() + aum.Signatures = []tkatype.Signature{ + { + KeyID: key.MustID(), + Signature: ed25519.Sign(priv, sigHash[:]), + }, + } + + if err := signatureVerify(&aum.Signatures[0], aum.SigHash(), key); err != nil { + t.Errorf("signature verification failed: %v", err) + } + + // Make sure it fails with a different public key. + pub2, _ := testingKey25519(t, 2) + key2 := Key{Kind: Key25519, Public: pub2} + if err := signatureVerify(&aum.Signatures[0], aum.SigHash(), key2); err == nil { + t.Error("signature verification with different key did not fail") + } +} + +func TestNLPrivate(t *testing.T) { + p := key.NewNLPrivate() + pub := p.Public() + + // Test that key.NLPrivate implements Signer by making a new + // authority. + k := Key{Kind: Key25519, Public: pub.Verifier(), Votes: 1} + _, aum, err := Create(&Mem{}, State{ + Keys: []Key{k}, + DisablementSecrets: [][]byte{bytes.Repeat([]byte{1}, 32)}, + }, p) + if err != nil { + t.Fatalf("Create() failed: %v", err) + } + + // Make sure the generated genesis AUM was signed. + if got, want := len(aum.Signatures), 1; got != want { + t.Fatalf("len(signatures) = %d, want %d", got, want) + } + sigHash := aum.SigHash() + if ok := ed25519.Verify(pub.Verifier(), sigHash[:], aum.Signatures[0].Signature); !ok { + t.Error("signature did not verify") + } + + // We manually compute the keyID, so make sure its consistent with + // tka.Key.ID(). + if !bytes.Equal(k.MustID(), p.KeyID()) { + t.Errorf("private.KeyID() & tka KeyID differ: %x != %x", k.MustID(), p.KeyID()) + } +} diff --git a/tka/state.go b/tka/state.go index e99b731cc..0a459bd9a 100644 --- a/tka/state.go +++ b/tka/state.go @@ -1,315 +1,315 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "bytes" - "errors" - "fmt" - - "golang.org/x/crypto/argon2" - "tailscale.com/types/tkatype" -) - -// ErrNoSuchKey is returned if the key referenced by a KeyID does not exist. -var ErrNoSuchKey = errors.New("key not found") - -// State describes Tailnet Key Authority state at an instant in time. -// -// State is mutated by applying Authority Update Messages (AUMs), resulting -// in a new State. -type State struct { - // LastAUMHash is the blake2s digest of the last-applied AUM. - // Because AUMs are strictly ordered and form a hash chain, we - // check the previous AUM hash in an update we are applying - // is the same as the LastAUMHash. - LastAUMHash *AUMHash `cbor:"1,keyasint"` - - // DisablementSecrets are KDF-derived values which can be used - // to turn off the TKA in the event of a consensus-breaking bug. - DisablementSecrets [][]byte `cbor:"2,keyasint"` - - // Keys are the public keys of either: - // - // 1. The signing nodes currently trusted by the TKA. - // 2. Ephemeral keys that were used to generate pre-signed auth keys. - Keys []Key `cbor:"3,keyasint"` - - // StateID's are nonce's, generated on enablement and fixed for - // the lifetime of the Tailnet Key Authority. We generate 16-bytes - // worth of keyspace here just in case we come up with a cool future - // use for this. - StateID1 uint64 `cbor:"4,keyasint,omitempty"` - StateID2 uint64 `cbor:"5,keyasint,omitempty"` -} - -// GetKey returns the trusted key with the specified KeyID. -func (s State) GetKey(key tkatype.KeyID) (Key, error) { - for _, k := range s.Keys { - keyID, err := k.ID() - if err != nil { - return Key{}, err - } - - if bytes.Equal(keyID, key) { - return k, nil - } - } - - return Key{}, ErrNoSuchKey -} - -// Clone makes an independent copy of State. -// -// NOTE: There is a difference between a nil slice and an empty -// slice for encoding purposes, so an implementation of Clone() -// must take care to preserve this. -func (s State) Clone() State { - out := State{ - StateID1: s.StateID1, - StateID2: s.StateID2, - } - - if s.LastAUMHash != nil { - dupe := *s.LastAUMHash - out.LastAUMHash = &dupe - } - - if s.DisablementSecrets != nil { - out.DisablementSecrets = make([][]byte, len(s.DisablementSecrets)) - for i := range s.DisablementSecrets { - out.DisablementSecrets[i] = make([]byte, len(s.DisablementSecrets[i])) - copy(out.DisablementSecrets[i], s.DisablementSecrets[i]) - } - } - - if s.Keys != nil { - out.Keys = make([]Key, len(s.Keys)) - for i := range s.Keys { - out.Keys[i] = s.Keys[i].Clone() - } - } - - return out -} - -// cloneForUpdate is like Clone, except LastAUMHash is set based -// on the hash of the given update. -func (s State) cloneForUpdate(update *AUM) State { - out := s.Clone() - aumHash := update.Hash() - out.LastAUMHash = &aumHash - return out -} - -const disablementLength = 32 - -var disablementSalt = []byte("tailscale network-lock disablement salt") - -// DisablementKDF computes a public value which can be stored in a -// key authority, but cannot be reversed to find the input secret. -// -// When the output of this function is stored in tka state (i.e. in -// tka.State.DisablementSecrets) a call to Authority.ValidDisablement() -// with the input of this function as the argument will return true. -func DisablementKDF(secret []byte) []byte { - // time = 4 (3 recommended, booped to 4 to compensate for less memory) - // memory = 16 (32 recommended) - // threads = 4 - // keyLen = 32 (256 bits) - return argon2.Key(secret, disablementSalt, 4, 16*1024, 4, disablementLength) -} - -// checkDisablement returns true for a valid disablement secret. -func (s State) checkDisablement(secret []byte) bool { - derived := DisablementKDF(secret) - for _, candidate := range s.DisablementSecrets { - if bytes.Equal(derived, candidate) { - return true - } - } - return false -} - -// parentMatches returns true if an AUM can chain to (be applied) -// to the current state. -// -// Specifically, the rules are: -// - The last AUM hash must match (transitively, this implies that this -// update follows the last update message applied to the state machine) -// - Or, the state machine knows no parent (its brand new). -func (s State) parentMatches(update AUM) bool { - if s.LastAUMHash == nil { - return true - } - return bytes.Equal(s.LastAUMHash[:], update.PrevAUMHash) -} - -// applyVerifiedAUM computes a new state based on the update provided. -// -// The provided update MUST be verified: That is, the AUM must be well-formed -// (as defined by StaticValidate()), and signatures over the AUM must have -// been verified. -func (s State) applyVerifiedAUM(update AUM) (State, error) { - // Validate that the update message has the right parent. - if !s.parentMatches(update) { - return State{}, errors.New("parent AUMHash mismatch") - } - - switch update.MessageKind { - case AUMNoOp: - out := s.cloneForUpdate(&update) - return out, nil - - case AUMCheckpoint: - if update.State == nil { - return State{}, errors.New("missing checkpoint state") - } - id1Match, id2Match := update.State.StateID1 == s.StateID1, update.State.StateID2 == s.StateID2 - if !id1Match || !id2Match { - return State{}, errors.New("checkpointed state has an incorrect stateID") - } - return update.State.cloneForUpdate(&update), nil - - case AUMAddKey: - if update.Key == nil { - return State{}, errors.New("no key to add provided") - } - keyID, err := update.Key.ID() - if err != nil { - return State{}, err - } - if _, err := s.GetKey(keyID); err == nil { - return State{}, errors.New("key already exists") - } - out := s.cloneForUpdate(&update) - out.Keys = append(out.Keys, *update.Key) - return out, nil - - case AUMUpdateKey: - k, err := s.GetKey(update.KeyID) - if err != nil { - return State{}, err - } - if update.Votes != nil { - k.Votes = *update.Votes - } - if update.Meta != nil { - k.Meta = update.Meta - } - if err := k.StaticValidate(); err != nil { - return State{}, fmt.Errorf("updated key fails validation: %v", err) - } - out := s.cloneForUpdate(&update) - for i := range out.Keys { - keyID, err := out.Keys[i].ID() - if err != nil { - return State{}, err - } - if bytes.Equal(keyID, update.KeyID) { - out.Keys[i] = k - } - } - return out, nil - - case AUMRemoveKey: - idx := -1 - for i := range s.Keys { - keyID, err := s.Keys[i].ID() - if err != nil { - return State{}, err - } - if bytes.Equal(update.KeyID, keyID) { - idx = i - break - } - } - if idx < 0 { - return State{}, ErrNoSuchKey - } - out := s.cloneForUpdate(&update) - out.Keys = append(out.Keys[:idx], out.Keys[idx+1:]...) - return out, nil - - default: - // An AUM with an unknown message kind was received! That means - // that a future version of tailscaled added some feature we don't - // understand. - // - // The future-compatibility contract for AUM message types is that - // they must only add new features, not change the semantics of existing - // mechanisms or features. As such, old clients can safely ignore them. - out := s.cloneForUpdate(&update) - return out, nil - } -} - -// Upper bound on checkpoint elements, chosen arbitrarily. Intended to -// cap out insanely large AUMs. -const ( - maxDisablementSecrets = 32 - maxKeys = 512 -) - -// staticValidateCheckpoint validates that the state is well-formed for -// inclusion in a checkpoint AUM. -func (s *State) staticValidateCheckpoint() error { - if s.LastAUMHash != nil { - return errors.New("cannot specify a parent AUM") - } - if len(s.DisablementSecrets) == 0 { - return errors.New("at least one disablement secret required") - } - if numDS := len(s.DisablementSecrets); numDS > maxDisablementSecrets { - return fmt.Errorf("too many disablement secrets (%d, max %d)", numDS, maxDisablementSecrets) - } - for i, ds := range s.DisablementSecrets { - if len(ds) != disablementLength { - return fmt.Errorf("disablement[%d]: invalid length (got %d, want %d)", i, len(ds), disablementLength) - } - for j, ds2 := range s.DisablementSecrets { - if i == j { - continue - } - if bytes.Equal(ds, ds2) { - return fmt.Errorf("disablement[%d]: duplicates disablement[%d]", i, j) - } - } - } - - if len(s.Keys) == 0 { - return errors.New("at least one key is required") - } - if numKeys := len(s.Keys); numKeys > maxKeys { - return fmt.Errorf("too many keys (%d, max %d)", numKeys, maxKeys) - } - for i, k := range s.Keys { - if err := k.StaticValidate(); err != nil { - return fmt.Errorf("key[%d]: %v", i, err) - } - } - // NOTE: The max number of keys is constrained (512), so - // O(n^2) is fine. - for i, k := range s.Keys { - for j, k2 := range s.Keys { - if i == j { - continue - } - - id1, err := k.ID() - if err != nil { - return fmt.Errorf("key[%d]: %w", i, err) - } - id2, err := k2.ID() - if err != nil { - return fmt.Errorf("key[%d]: %w", j, err) - } - - if bytes.Equal(id1, id2) { - return fmt.Errorf("key[%d]: duplicates key[%d]", i, j) - } - } - } - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "bytes" + "errors" + "fmt" + + "golang.org/x/crypto/argon2" + "tailscale.com/types/tkatype" +) + +// ErrNoSuchKey is returned if the key referenced by a KeyID does not exist. +var ErrNoSuchKey = errors.New("key not found") + +// State describes Tailnet Key Authority state at an instant in time. +// +// State is mutated by applying Authority Update Messages (AUMs), resulting +// in a new State. +type State struct { + // LastAUMHash is the blake2s digest of the last-applied AUM. + // Because AUMs are strictly ordered and form a hash chain, we + // check the previous AUM hash in an update we are applying + // is the same as the LastAUMHash. + LastAUMHash *AUMHash `cbor:"1,keyasint"` + + // DisablementSecrets are KDF-derived values which can be used + // to turn off the TKA in the event of a consensus-breaking bug. + DisablementSecrets [][]byte `cbor:"2,keyasint"` + + // Keys are the public keys of either: + // + // 1. The signing nodes currently trusted by the TKA. + // 2. Ephemeral keys that were used to generate pre-signed auth keys. + Keys []Key `cbor:"3,keyasint"` + + // StateID's are nonce's, generated on enablement and fixed for + // the lifetime of the Tailnet Key Authority. We generate 16-bytes + // worth of keyspace here just in case we come up with a cool future + // use for this. + StateID1 uint64 `cbor:"4,keyasint,omitempty"` + StateID2 uint64 `cbor:"5,keyasint,omitempty"` +} + +// GetKey returns the trusted key with the specified KeyID. +func (s State) GetKey(key tkatype.KeyID) (Key, error) { + for _, k := range s.Keys { + keyID, err := k.ID() + if err != nil { + return Key{}, err + } + + if bytes.Equal(keyID, key) { + return k, nil + } + } + + return Key{}, ErrNoSuchKey +} + +// Clone makes an independent copy of State. +// +// NOTE: There is a difference between a nil slice and an empty +// slice for encoding purposes, so an implementation of Clone() +// must take care to preserve this. +func (s State) Clone() State { + out := State{ + StateID1: s.StateID1, + StateID2: s.StateID2, + } + + if s.LastAUMHash != nil { + dupe := *s.LastAUMHash + out.LastAUMHash = &dupe + } + + if s.DisablementSecrets != nil { + out.DisablementSecrets = make([][]byte, len(s.DisablementSecrets)) + for i := range s.DisablementSecrets { + out.DisablementSecrets[i] = make([]byte, len(s.DisablementSecrets[i])) + copy(out.DisablementSecrets[i], s.DisablementSecrets[i]) + } + } + + if s.Keys != nil { + out.Keys = make([]Key, len(s.Keys)) + for i := range s.Keys { + out.Keys[i] = s.Keys[i].Clone() + } + } + + return out +} + +// cloneForUpdate is like Clone, except LastAUMHash is set based +// on the hash of the given update. +func (s State) cloneForUpdate(update *AUM) State { + out := s.Clone() + aumHash := update.Hash() + out.LastAUMHash = &aumHash + return out +} + +const disablementLength = 32 + +var disablementSalt = []byte("tailscale network-lock disablement salt") + +// DisablementKDF computes a public value which can be stored in a +// key authority, but cannot be reversed to find the input secret. +// +// When the output of this function is stored in tka state (i.e. in +// tka.State.DisablementSecrets) a call to Authority.ValidDisablement() +// with the input of this function as the argument will return true. +func DisablementKDF(secret []byte) []byte { + // time = 4 (3 recommended, booped to 4 to compensate for less memory) + // memory = 16 (32 recommended) + // threads = 4 + // keyLen = 32 (256 bits) + return argon2.Key(secret, disablementSalt, 4, 16*1024, 4, disablementLength) +} + +// checkDisablement returns true for a valid disablement secret. +func (s State) checkDisablement(secret []byte) bool { + derived := DisablementKDF(secret) + for _, candidate := range s.DisablementSecrets { + if bytes.Equal(derived, candidate) { + return true + } + } + return false +} + +// parentMatches returns true if an AUM can chain to (be applied) +// to the current state. +// +// Specifically, the rules are: +// - The last AUM hash must match (transitively, this implies that this +// update follows the last update message applied to the state machine) +// - Or, the state machine knows no parent (its brand new). +func (s State) parentMatches(update AUM) bool { + if s.LastAUMHash == nil { + return true + } + return bytes.Equal(s.LastAUMHash[:], update.PrevAUMHash) +} + +// applyVerifiedAUM computes a new state based on the update provided. +// +// The provided update MUST be verified: That is, the AUM must be well-formed +// (as defined by StaticValidate()), and signatures over the AUM must have +// been verified. +func (s State) applyVerifiedAUM(update AUM) (State, error) { + // Validate that the update message has the right parent. + if !s.parentMatches(update) { + return State{}, errors.New("parent AUMHash mismatch") + } + + switch update.MessageKind { + case AUMNoOp: + out := s.cloneForUpdate(&update) + return out, nil + + case AUMCheckpoint: + if update.State == nil { + return State{}, errors.New("missing checkpoint state") + } + id1Match, id2Match := update.State.StateID1 == s.StateID1, update.State.StateID2 == s.StateID2 + if !id1Match || !id2Match { + return State{}, errors.New("checkpointed state has an incorrect stateID") + } + return update.State.cloneForUpdate(&update), nil + + case AUMAddKey: + if update.Key == nil { + return State{}, errors.New("no key to add provided") + } + keyID, err := update.Key.ID() + if err != nil { + return State{}, err + } + if _, err := s.GetKey(keyID); err == nil { + return State{}, errors.New("key already exists") + } + out := s.cloneForUpdate(&update) + out.Keys = append(out.Keys, *update.Key) + return out, nil + + case AUMUpdateKey: + k, err := s.GetKey(update.KeyID) + if err != nil { + return State{}, err + } + if update.Votes != nil { + k.Votes = *update.Votes + } + if update.Meta != nil { + k.Meta = update.Meta + } + if err := k.StaticValidate(); err != nil { + return State{}, fmt.Errorf("updated key fails validation: %v", err) + } + out := s.cloneForUpdate(&update) + for i := range out.Keys { + keyID, err := out.Keys[i].ID() + if err != nil { + return State{}, err + } + if bytes.Equal(keyID, update.KeyID) { + out.Keys[i] = k + } + } + return out, nil + + case AUMRemoveKey: + idx := -1 + for i := range s.Keys { + keyID, err := s.Keys[i].ID() + if err != nil { + return State{}, err + } + if bytes.Equal(update.KeyID, keyID) { + idx = i + break + } + } + if idx < 0 { + return State{}, ErrNoSuchKey + } + out := s.cloneForUpdate(&update) + out.Keys = append(out.Keys[:idx], out.Keys[idx+1:]...) + return out, nil + + default: + // An AUM with an unknown message kind was received! That means + // that a future version of tailscaled added some feature we don't + // understand. + // + // The future-compatibility contract for AUM message types is that + // they must only add new features, not change the semantics of existing + // mechanisms or features. As such, old clients can safely ignore them. + out := s.cloneForUpdate(&update) + return out, nil + } +} + +// Upper bound on checkpoint elements, chosen arbitrarily. Intended to +// cap out insanely large AUMs. +const ( + maxDisablementSecrets = 32 + maxKeys = 512 +) + +// staticValidateCheckpoint validates that the state is well-formed for +// inclusion in a checkpoint AUM. +func (s *State) staticValidateCheckpoint() error { + if s.LastAUMHash != nil { + return errors.New("cannot specify a parent AUM") + } + if len(s.DisablementSecrets) == 0 { + return errors.New("at least one disablement secret required") + } + if numDS := len(s.DisablementSecrets); numDS > maxDisablementSecrets { + return fmt.Errorf("too many disablement secrets (%d, max %d)", numDS, maxDisablementSecrets) + } + for i, ds := range s.DisablementSecrets { + if len(ds) != disablementLength { + return fmt.Errorf("disablement[%d]: invalid length (got %d, want %d)", i, len(ds), disablementLength) + } + for j, ds2 := range s.DisablementSecrets { + if i == j { + continue + } + if bytes.Equal(ds, ds2) { + return fmt.Errorf("disablement[%d]: duplicates disablement[%d]", i, j) + } + } + } + + if len(s.Keys) == 0 { + return errors.New("at least one key is required") + } + if numKeys := len(s.Keys); numKeys > maxKeys { + return fmt.Errorf("too many keys (%d, max %d)", numKeys, maxKeys) + } + for i, k := range s.Keys { + if err := k.StaticValidate(); err != nil { + return fmt.Errorf("key[%d]: %v", i, err) + } + } + // NOTE: The max number of keys is constrained (512), so + // O(n^2) is fine. + for i, k := range s.Keys { + for j, k2 := range s.Keys { + if i == j { + continue + } + + id1, err := k.ID() + if err != nil { + return fmt.Errorf("key[%d]: %w", i, err) + } + id2, err := k2.ID() + if err != nil { + return fmt.Errorf("key[%d]: %w", j, err) + } + + if bytes.Equal(id1, id2) { + return fmt.Errorf("key[%d]: duplicates key[%d]", i, j) + } + } + } + return nil +} diff --git a/tka/state_test.go b/tka/state_test.go index b8337dd8a..060bd9350 100644 --- a/tka/state_test.go +++ b/tka/state_test.go @@ -1,260 +1,260 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "bytes" - "encoding/hex" - "errors" - "testing" - - "github.com/fxamacker/cbor/v2" - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" -) - -func fromHex(in string) []byte { - out, err := hex.DecodeString(in) - if err != nil { - panic(err) - } - return out -} - -func hashFromHex(in string) *AUMHash { - var out AUMHash - copy(out[:], fromHex(in)) - return &out -} - -func TestCloneState(t *testing.T) { - tcs := []struct { - Name string - State State - }{ - { - "Empty", - State{}, - }, - { - "Key", - State{ - Keys: []Key{{Kind: Key25519, Votes: 2, Public: []byte{5, 6, 7, 8}, Meta: map[string]string{"a": "b"}}}, - }, - }, - { - "StateID", - State{ - StateID1: 42, - StateID2: 22, - }, - }, - { - "DisablementSecrets", - State{ - DisablementSecrets: [][]byte{ - {1, 2, 3, 4}, - {5, 6, 7, 8}, - }, - }, - }, - } - - for _, tc := range tcs { - t.Run(tc.Name, func(t *testing.T) { - if diff := cmp.Diff(tc.State, tc.State.Clone()); diff != "" { - t.Errorf("output state differs (-want, +got):\n%s", diff) - } - - // Make sure the cloned State is the same even after - // an encode + decode into + from CBOR. - t.Run("cbor", func(t *testing.T) { - out := bytes.NewBuffer(nil) - encoder, err := cbor.CTAP2EncOptions().EncMode() - if err != nil { - t.Fatal(err) - } - if err := encoder.NewEncoder(out).Encode(tc.State.Clone()); err != nil { - t.Fatal(err) - } - - var decodedState State - if err := cbor.Unmarshal(out.Bytes(), &decodedState); err != nil { - t.Fatalf("Unmarshal failed: %v", err) - } - if diff := cmp.Diff(tc.State, decodedState); diff != "" { - t.Errorf("decoded state differs (-want, +got):\n%s", diff) - } - }) - }) - } -} - -func TestApplyUpdatesChain(t *testing.T) { - intOne := uint(1) - tcs := []struct { - Name string - Updates []AUM - Start State - End State - }{ - { - "AddKey", - []AUM{{MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}}, - State{}, - State{ - Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, - LastAUMHash: hashFromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03"), - }, - }, - { - "RemoveKey", - []AUM{{MessageKind: AUMRemoveKey, KeyID: []byte{1, 2, 3, 4}, PrevAUMHash: fromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03")}}, - State{ - Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, - LastAUMHash: hashFromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03"), - }, - State{ - LastAUMHash: hashFromHex("15d65756abfafbb592279503f40759898590c9c59056be1e2e9f02684c15ba4b"), - }, - }, - { - "UpdateKey", - []AUM{{MessageKind: AUMUpdateKey, KeyID: []byte{1, 2, 3, 4}, Votes: &intOne, Meta: map[string]string{"a": "b"}, PrevAUMHash: fromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03")}}, - State{ - Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, - LastAUMHash: hashFromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03"), - }, - State{ - LastAUMHash: hashFromHex("d55458a9c3ed6997439ba5a18b9b62d2c6e5e0c1bb4c61409e92a1281a3b458d"), - Keys: []Key{{Kind: Key25519, Votes: 1, Meta: map[string]string{"a": "b"}, Public: []byte{1, 2, 3, 4}}}, - }, - }, - { - "ChainedKeyUpdates", - []AUM{ - {MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{5, 6, 7, 8}}}, - {MessageKind: AUMRemoveKey, KeyID: []byte{1, 2, 3, 4}, PrevAUMHash: fromHex("f09bda3bb7cf6756ea9adc25770aede4b3ca8142949d6ef5ca0add29af912fd4")}, - }, - State{ - Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, - }, - State{ - Keys: []Key{{Kind: Key25519, Public: []byte{5, 6, 7, 8}}}, - LastAUMHash: hashFromHex("218165fe5f757304b9deaff4ac742890364f5f509e533c74e80e0ce35e44ee1d"), - }, - }, - { - "Checkpoint", - []AUM{ - {MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{5, 6, 7, 8}}}, - {MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, - }, PrevAUMHash: fromHex("f09bda3bb7cf6756ea9adc25770aede4b3ca8142949d6ef5ca0add29af912fd4")}, - }, - State{DisablementSecrets: [][]byte{{1, 2, 3, 4}}}, - State{ - Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, - LastAUMHash: hashFromHex("57343671da5eea3cfb502954e976e8028bffd3540b50a043b2a65a8d8d8217d0"), - }, - }, - } - - for _, tc := range tcs { - t.Run(tc.Name, func(t *testing.T) { - state := tc.Start - for i := range tc.Updates { - var err error - // t.Logf("update[%d] start-state = %+v", i, state) - state, err = state.applyVerifiedAUM(tc.Updates[i]) - if err != nil { - t.Fatalf("Apply message[%d] failed: %v", i, err) - } - // t.Logf("update[%d] end-state = %+v", i, state) - - updateHash := tc.Updates[i].Hash() - if got, want := *state.LastAUMHash, updateHash[:]; !bytes.Equal(got[:], want) { - t.Errorf("expected state.LastAUMHash = %x (update %d), got %x", want, i, got) - } - } - - if diff := cmp.Diff(tc.End, state, cmpopts.EquateEmpty()); diff != "" { - t.Errorf("output state differs (+got, -want):\n%s", diff) - } - }) - } -} - -func TestApplyUpdateErrors(t *testing.T) { - tooLargeVotes := uint(99999) - tcs := []struct { - Name string - Updates []AUM - Start State - Error error - }{ - { - "AddKey exists", - []AUM{{MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}}, - State{Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}}, - errors.New("key already exists"), - }, - { - "RemoveKey notfound", - []AUM{{MessageKind: AUMRemoveKey, Key: &Key{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}}, - State{}, - ErrNoSuchKey, - }, - { - "UpdateKey notfound", - []AUM{{MessageKind: AUMUpdateKey, KeyID: []byte{1}}}, - State{}, - ErrNoSuchKey, - }, - { - "UpdateKey now fails validation", - []AUM{{MessageKind: AUMUpdateKey, KeyID: []byte{1}, Votes: &tooLargeVotes}}, - State{Keys: []Key{{Kind: Key25519, Public: []byte{1}}}}, - errors.New("updated key fails validation: excessive key weight: 99999 > 4096"), - }, - { - "Bad lastAUMHash", - []AUM{ - {MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{5, 6, 7, 8}}}, - {MessageKind: AUMRemoveKey, KeyID: []byte{1, 2, 3, 4}, PrevAUMHash: fromHex("1234")}, - }, - State{ - Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, - }, - errors.New("parent AUMHash mismatch"), - }, - { - "Bad StateID", - []AUM{{MessageKind: AUMCheckpoint, State: &State{StateID1: 1}}}, - State{Keys: []Key{{Kind: Key25519, Public: []byte{1}}}, StateID1: 42}, - errors.New("checkpointed state has an incorrect stateID"), - }, - } - - for _, tc := range tcs { - t.Run(tc.Name, func(t *testing.T) { - state := tc.Start - for i := range tc.Updates { - var err error - // t.Logf("update[%d] start-state = %+v", i, state) - state, err = state.applyVerifiedAUM(tc.Updates[i]) - if err != nil { - if err.Error() != tc.Error.Error() { - t.Errorf("state[%d].Err = %v, want %v", i, err, tc.Error) - } else { - return - } - } - // t.Logf("update[%d] end-state = %+v", i, state) - } - - t.Errorf("did not error, expected %v", tc.Error) - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "bytes" + "encoding/hex" + "errors" + "testing" + + "github.com/fxamacker/cbor/v2" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" +) + +func fromHex(in string) []byte { + out, err := hex.DecodeString(in) + if err != nil { + panic(err) + } + return out +} + +func hashFromHex(in string) *AUMHash { + var out AUMHash + copy(out[:], fromHex(in)) + return &out +} + +func TestCloneState(t *testing.T) { + tcs := []struct { + Name string + State State + }{ + { + "Empty", + State{}, + }, + { + "Key", + State{ + Keys: []Key{{Kind: Key25519, Votes: 2, Public: []byte{5, 6, 7, 8}, Meta: map[string]string{"a": "b"}}}, + }, + }, + { + "StateID", + State{ + StateID1: 42, + StateID2: 22, + }, + }, + { + "DisablementSecrets", + State{ + DisablementSecrets: [][]byte{ + {1, 2, 3, 4}, + {5, 6, 7, 8}, + }, + }, + }, + } + + for _, tc := range tcs { + t.Run(tc.Name, func(t *testing.T) { + if diff := cmp.Diff(tc.State, tc.State.Clone()); diff != "" { + t.Errorf("output state differs (-want, +got):\n%s", diff) + } + + // Make sure the cloned State is the same even after + // an encode + decode into + from CBOR. + t.Run("cbor", func(t *testing.T) { + out := bytes.NewBuffer(nil) + encoder, err := cbor.CTAP2EncOptions().EncMode() + if err != nil { + t.Fatal(err) + } + if err := encoder.NewEncoder(out).Encode(tc.State.Clone()); err != nil { + t.Fatal(err) + } + + var decodedState State + if err := cbor.Unmarshal(out.Bytes(), &decodedState); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if diff := cmp.Diff(tc.State, decodedState); diff != "" { + t.Errorf("decoded state differs (-want, +got):\n%s", diff) + } + }) + }) + } +} + +func TestApplyUpdatesChain(t *testing.T) { + intOne := uint(1) + tcs := []struct { + Name string + Updates []AUM + Start State + End State + }{ + { + "AddKey", + []AUM{{MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}}, + State{}, + State{ + Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, + LastAUMHash: hashFromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03"), + }, + }, + { + "RemoveKey", + []AUM{{MessageKind: AUMRemoveKey, KeyID: []byte{1, 2, 3, 4}, PrevAUMHash: fromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03")}}, + State{ + Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, + LastAUMHash: hashFromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03"), + }, + State{ + LastAUMHash: hashFromHex("15d65756abfafbb592279503f40759898590c9c59056be1e2e9f02684c15ba4b"), + }, + }, + { + "UpdateKey", + []AUM{{MessageKind: AUMUpdateKey, KeyID: []byte{1, 2, 3, 4}, Votes: &intOne, Meta: map[string]string{"a": "b"}, PrevAUMHash: fromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03")}}, + State{ + Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, + LastAUMHash: hashFromHex("53898e4311d0b6087fcbb871563868a16c629d9267df851fcfa7b52b31d2bd03"), + }, + State{ + LastAUMHash: hashFromHex("d55458a9c3ed6997439ba5a18b9b62d2c6e5e0c1bb4c61409e92a1281a3b458d"), + Keys: []Key{{Kind: Key25519, Votes: 1, Meta: map[string]string{"a": "b"}, Public: []byte{1, 2, 3, 4}}}, + }, + }, + { + "ChainedKeyUpdates", + []AUM{ + {MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{5, 6, 7, 8}}}, + {MessageKind: AUMRemoveKey, KeyID: []byte{1, 2, 3, 4}, PrevAUMHash: fromHex("f09bda3bb7cf6756ea9adc25770aede4b3ca8142949d6ef5ca0add29af912fd4")}, + }, + State{ + Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, + }, + State{ + Keys: []Key{{Kind: Key25519, Public: []byte{5, 6, 7, 8}}}, + LastAUMHash: hashFromHex("218165fe5f757304b9deaff4ac742890364f5f509e533c74e80e0ce35e44ee1d"), + }, + }, + { + "Checkpoint", + []AUM{ + {MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{5, 6, 7, 8}}}, + {MessageKind: AUMCheckpoint, State: &State{ + Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, + }, PrevAUMHash: fromHex("f09bda3bb7cf6756ea9adc25770aede4b3ca8142949d6ef5ca0add29af912fd4")}, + }, + State{DisablementSecrets: [][]byte{{1, 2, 3, 4}}}, + State{ + Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, + LastAUMHash: hashFromHex("57343671da5eea3cfb502954e976e8028bffd3540b50a043b2a65a8d8d8217d0"), + }, + }, + } + + for _, tc := range tcs { + t.Run(tc.Name, func(t *testing.T) { + state := tc.Start + for i := range tc.Updates { + var err error + // t.Logf("update[%d] start-state = %+v", i, state) + state, err = state.applyVerifiedAUM(tc.Updates[i]) + if err != nil { + t.Fatalf("Apply message[%d] failed: %v", i, err) + } + // t.Logf("update[%d] end-state = %+v", i, state) + + updateHash := tc.Updates[i].Hash() + if got, want := *state.LastAUMHash, updateHash[:]; !bytes.Equal(got[:], want) { + t.Errorf("expected state.LastAUMHash = %x (update %d), got %x", want, i, got) + } + } + + if diff := cmp.Diff(tc.End, state, cmpopts.EquateEmpty()); diff != "" { + t.Errorf("output state differs (+got, -want):\n%s", diff) + } + }) + } +} + +func TestApplyUpdateErrors(t *testing.T) { + tooLargeVotes := uint(99999) + tcs := []struct { + Name string + Updates []AUM + Start State + Error error + }{ + { + "AddKey exists", + []AUM{{MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}}, + State{Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}}, + errors.New("key already exists"), + }, + { + "RemoveKey notfound", + []AUM{{MessageKind: AUMRemoveKey, Key: &Key{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}}, + State{}, + ErrNoSuchKey, + }, + { + "UpdateKey notfound", + []AUM{{MessageKind: AUMUpdateKey, KeyID: []byte{1}}}, + State{}, + ErrNoSuchKey, + }, + { + "UpdateKey now fails validation", + []AUM{{MessageKind: AUMUpdateKey, KeyID: []byte{1}, Votes: &tooLargeVotes}}, + State{Keys: []Key{{Kind: Key25519, Public: []byte{1}}}}, + errors.New("updated key fails validation: excessive key weight: 99999 > 4096"), + }, + { + "Bad lastAUMHash", + []AUM{ + {MessageKind: AUMAddKey, Key: &Key{Kind: Key25519, Public: []byte{5, 6, 7, 8}}}, + {MessageKind: AUMRemoveKey, KeyID: []byte{1, 2, 3, 4}, PrevAUMHash: fromHex("1234")}, + }, + State{ + Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, + }, + errors.New("parent AUMHash mismatch"), + }, + { + "Bad StateID", + []AUM{{MessageKind: AUMCheckpoint, State: &State{StateID1: 1}}}, + State{Keys: []Key{{Kind: Key25519, Public: []byte{1}}}, StateID1: 42}, + errors.New("checkpointed state has an incorrect stateID"), + }, + } + + for _, tc := range tcs { + t.Run(tc.Name, func(t *testing.T) { + state := tc.Start + for i := range tc.Updates { + var err error + // t.Logf("update[%d] start-state = %+v", i, state) + state, err = state.applyVerifiedAUM(tc.Updates[i]) + if err != nil { + if err.Error() != tc.Error.Error() { + t.Errorf("state[%d].Err = %v, want %v", i, err, tc.Error) + } else { + return + } + } + // t.Logf("update[%d] end-state = %+v", i, state) + } + + t.Errorf("did not error, expected %v", tc.Error) + }) + } +} diff --git a/tka/sync_test.go b/tka/sync_test.go index d214020c4..7250eacf7 100644 --- a/tka/sync_test.go +++ b/tka/sync_test.go @@ -1,377 +1,377 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "bytes" - "strconv" - "testing" - - "github.com/google/go-cmp/cmp" -) - -func TestSyncOffer(t *testing.T) { - c := newTestchain(t, ` - A1 -> A2 -> A3 -> A4 -> A5 -> A6 -> A7 -> A8 -> A9 -> A10 - A10 -> A11 -> A12 -> A13 -> A14 -> A15 -> A16 -> A17 -> A18 - A18 -> A19 -> A20 -> A21 -> A22 -> A23 -> A24 -> A25 - `) - storage := c.Chonk() - a, err := Open(storage) - if err != nil { - t.Fatal(err) - } - got, err := a.SyncOffer(storage) - if err != nil { - t.Fatal(err) - } - - // A SyncOffer includes a selection of AUMs going backwards in the tree, - // progressively skipping more and more each iteration. - want := SyncOffer{ - Head: c.AUMHashes["A25"], - Ancestors: []AUMHash{ - c.AUMHashes["A"+strconv.Itoa(25-ancestorsSkipStart)], - c.AUMHashes["A"+strconv.Itoa(25-ancestorsSkipStart< A2 - // Node 2 has: A1 -> A2 -> A3 -> A4 - c := newTestchain(t, ` - A1 -> A2 -> A3 -> A4 - `) - a1H, a2H := c.AUMHashes["A1"], c.AUMHashes["A2"] - - chonk1 := c.ChonkWith("A1", "A2") - n1, err := Open(chonk1) - if err != nil { - t.Fatal(err) - } - offer1, err := n1.SyncOffer(chonk1) - if err != nil { - t.Fatal(err) - } - - chonk2 := c.Chonk() // All AUMs - n2, err := Open(chonk2) - if err != nil { - t.Fatal(err) - } - offer2, err := n2.SyncOffer(chonk2) - if err != nil { - t.Fatal(err) - } - - // Node 1 only knows about the first two nodes, so the head of n2 is - // alien to it. - t.Run("n1", func(t *testing.T) { - got, err := computeSyncIntersection(chonk1, offer1, offer2) - if err != nil { - t.Fatalf("computeSyncIntersection() failed: %v", err) - } - want := &intersection{ - tailIntersection: &a1H, - } - if diff := cmp.Diff(want, got, cmp.AllowUnexported(intersection{})); diff != "" { - t.Errorf("intersection diff (-want, +got):\n%s", diff) - } - }) - - // Node 2 knows about the full chain, so it can see that the head of n1 - // intersects with a subset of its chain (a Head Intersection). - t.Run("n2", func(t *testing.T) { - got, err := computeSyncIntersection(chonk2, offer2, offer1) - if err != nil { - t.Fatalf("computeSyncIntersection() failed: %v", err) - } - want := &intersection{ - headIntersection: &a2H, - } - if diff := cmp.Diff(want, got, cmp.AllowUnexported(intersection{})); diff != "" { - t.Errorf("intersection diff (-want, +got):\n%s", diff) - } - }) -} - -func TestComputeSyncIntersection_ForkSmallDiff(t *testing.T) { - // The number of nodes in the chain is longer than ancestorSkipStart, - // so that during sync both nodes are able to find a common ancestor - // which was later than A1. - - c := newTestchain(t, ` - A1 -> A2 -> A3 -> A4 -> A5 -> A6 -> A7 -> A8 -> A9 -> A10 - | -> F1 - // Make F1 different to A9. - // hashSeed is chosen such that the hash is higher than A9. - F1.hashSeed = 7 - `) - // Node 1 has: A1 -> A2 -> A3 -> A4 -> A5 -> A6 -> A7 -> A8 -> F1 - // Node 2 has: A1 -> A2 -> A3 -> A4 -> A5 -> A6 -> A7 -> A8 -> A9 -> A10 - f1H, a9H := c.AUMHashes["F1"], c.AUMHashes["A9"] - - if bytes.Compare(f1H[:], a9H[:]) < 0 { - t.Fatal("failed assert: h(a9) > h(f1H)\nTweak hashSeed till this passes") - } - - chonk1 := c.ChonkWith("A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "F1") - n1, err := Open(chonk1) - if err != nil { - t.Fatal(err) - } - offer1, err := n1.SyncOffer(chonk1) - if err != nil { - t.Fatal(err) - } - if diff := cmp.Diff(SyncOffer{ - Head: c.AUMHashes["F1"], - Ancestors: []AUMHash{ - c.AUMHashes["A"+strconv.Itoa(9-ancestorsSkipStart)], - c.AUMHashes["A1"], - }, - }, offer1); diff != "" { - t.Errorf("offer1 diff (-want, +got):\n%s", diff) - } - - chonk2 := c.ChonkWith("A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9", "A10") - n2, err := Open(chonk2) - if err != nil { - t.Fatal(err) - } - offer2, err := n2.SyncOffer(chonk2) - if err != nil { - t.Fatal(err) - } - if diff := cmp.Diff(SyncOffer{ - Head: c.AUMHashes["A10"], - Ancestors: []AUMHash{ - c.AUMHashes["A"+strconv.Itoa(10-ancestorsSkipStart)], - c.AUMHashes["A1"], - }, - }, offer2); diff != "" { - t.Errorf("offer2 diff (-want, +got):\n%s", diff) - } - - // Node 1 only knows about the first eight nodes, so the head of n2 is - // alien to it. - t.Run("n1", func(t *testing.T) { - // n2 has 10 nodes, so the first common ancestor should be 10-ancestorsSkipStart - wantIntersection := c.AUMHashes["A"+strconv.Itoa(10-ancestorsSkipStart)] - - got, err := computeSyncIntersection(chonk1, offer1, offer2) - if err != nil { - t.Fatalf("computeSyncIntersection() failed: %v", err) - } - want := &intersection{ - tailIntersection: &wantIntersection, - } - if diff := cmp.Diff(want, got, cmp.AllowUnexported(intersection{})); diff != "" { - t.Errorf("intersection diff (-want, +got):\n%s", diff) - } - }) - - // Node 2 knows about the full chain but doesn't recognize the head. - t.Run("n2", func(t *testing.T) { - // n1 has 9 nodes, so the first common ancestor should be 9-ancestorsSkipStart - wantIntersection := c.AUMHashes["A"+strconv.Itoa(9-ancestorsSkipStart)] - - got, err := computeSyncIntersection(chonk2, offer2, offer1) - if err != nil { - t.Fatalf("computeSyncIntersection() failed: %v", err) - } - want := &intersection{ - tailIntersection: &wantIntersection, - } - if diff := cmp.Diff(want, got, cmp.AllowUnexported(intersection{})); diff != "" { - t.Errorf("intersection diff (-want, +got):\n%s", diff) - } - }) -} - -func TestMissingAUMs_FastForward(t *testing.T) { - // Node 1 has: A1 -> A2 - // Node 2 has: A1 -> A2 -> A3 -> A4 - c := newTestchain(t, ` - A1 -> A2 -> A3 -> A4 - A1.hashSeed = 1 - A2.hashSeed = 2 - A3.hashSeed = 3 - A4.hashSeed = 4 - `) - - chonk1 := c.ChonkWith("A1", "A2") - n1, err := Open(chonk1) - if err != nil { - t.Fatal(err) - } - offer1, err := n1.SyncOffer(chonk1) - if err != nil { - t.Fatal(err) - } - - chonk2 := c.Chonk() // All AUMs - n2, err := Open(chonk2) - if err != nil { - t.Fatal(err) - } - offer2, err := n2.SyncOffer(chonk2) - if err != nil { - t.Fatal(err) - } - - // Node 1 only knows about the first two nodes, so the head of n2 is - // alien to it. As such, it should send history from the newest ancestor, - // A1 (if the chain was longer there would be one in the middle). - t.Run("n1", func(t *testing.T) { - got, err := n1.MissingAUMs(chonk1, offer2) - if err != nil { - t.Fatalf("MissingAUMs() failed: %v", err) - } - - // Both sides have A1, so the only AUM that n2 might not have is - // A2. - want := []AUM{c.AUMs["A2"]} - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("MissingAUMs diff (-want, +got):\n%s", diff) - } - }) - - // Node 2 knows about the full chain, so it can see that the head of n1 - // intersects with a subset of its chain (a Head Intersection). - t.Run("n2", func(t *testing.T) { - got, err := n2.MissingAUMs(chonk2, offer1) - if err != nil { - t.Fatalf("MissingAUMs() failed: %v", err) - } - - want := []AUM{ - c.AUMs["A3"], - c.AUMs["A4"], - } - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("MissingAUMs diff (-want, +got):\n%s", diff) - } - }) -} - -func TestMissingAUMs_Fork(t *testing.T) { - // Node 1 has: A1 -> A2 -> A3 -> F1 - // Node 2 has: A1 -> A2 -> A3 -> A4 - c := newTestchain(t, ` - A1 -> A2 -> A3 -> A4 - | -> F1 - A1.hashSeed = 1 - A2.hashSeed = 2 - A3.hashSeed = 3 - A4.hashSeed = 4 - `) - - chonk1 := c.ChonkWith("A1", "A2", "A3", "F1") - n1, err := Open(chonk1) - if err != nil { - t.Fatal(err) - } - offer1, err := n1.SyncOffer(chonk1) - if err != nil { - t.Fatal(err) - } - - chonk2 := c.ChonkWith("A1", "A2", "A3", "A4") - n2, err := Open(chonk2) - if err != nil { - t.Fatal(err) - } - offer2, err := n2.SyncOffer(chonk2) - if err != nil { - t.Fatal(err) - } - - t.Run("n1", func(t *testing.T) { - got, err := n1.MissingAUMs(chonk1, offer2) - if err != nil { - t.Fatalf("MissingAUMs() failed: %v", err) - } - - // Both sides have A1, so n1 will send everything it knows from - // there to head. - want := []AUM{ - c.AUMs["A2"], - c.AUMs["A3"], - c.AUMs["F1"], - } - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("MissingAUMs diff (-want, +got):\n%s", diff) - } - }) - - t.Run("n2", func(t *testing.T) { - got, err := n2.MissingAUMs(chonk2, offer1) - if err != nil { - t.Fatalf("MissingAUMs() failed: %v", err) - } - - // Both sides have A1, so n2 will send everything it knows from - // there to head. - want := []AUM{ - c.AUMs["A2"], - c.AUMs["A3"], - c.AUMs["A4"], - } - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("MissingAUMs diff (-want, +got):\n%s", diff) - } - }) -} - -func TestSyncSimpleE2E(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - c := newTestchain(t, ` - G1 -> L1 -> L2 -> L3 - G1.template = genesis - `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), - optKey("key", key, priv), - optSignAllUsing("key")) - - nodeStorage := &Mem{} - node, err := Bootstrap(nodeStorage, c.AUMs["G1"]) - if err != nil { - t.Fatalf("node Bootstrap() failed: %v", err) - } - controlStorage := c.Chonk() - control, err := Open(controlStorage) - if err != nil { - t.Fatalf("control Open() failed: %v", err) - } - - // Control knows the full chain, node only knows the genesis. Lets see - // if they can sync. - nodeOffer, err := node.SyncOffer(nodeStorage) - if err != nil { - t.Fatal(err) - } - controlAUMs, err := control.MissingAUMs(controlStorage, nodeOffer) - if err != nil { - t.Fatalf("control.MissingAUMs(%v) failed: %v", nodeOffer, err) - } - if err := node.Inform(nodeStorage, controlAUMs); err != nil { - t.Fatalf("node.Inform(%v) failed: %v", controlAUMs, err) - } - - if cHash, nHash := control.Head(), node.Head(); cHash != nHash { - t.Errorf("node & control are not synced: c=%x, n=%x", cHash, nHash) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "bytes" + "strconv" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestSyncOffer(t *testing.T) { + c := newTestchain(t, ` + A1 -> A2 -> A3 -> A4 -> A5 -> A6 -> A7 -> A8 -> A9 -> A10 + A10 -> A11 -> A12 -> A13 -> A14 -> A15 -> A16 -> A17 -> A18 + A18 -> A19 -> A20 -> A21 -> A22 -> A23 -> A24 -> A25 + `) + storage := c.Chonk() + a, err := Open(storage) + if err != nil { + t.Fatal(err) + } + got, err := a.SyncOffer(storage) + if err != nil { + t.Fatal(err) + } + + // A SyncOffer includes a selection of AUMs going backwards in the tree, + // progressively skipping more and more each iteration. + want := SyncOffer{ + Head: c.AUMHashes["A25"], + Ancestors: []AUMHash{ + c.AUMHashes["A"+strconv.Itoa(25-ancestorsSkipStart)], + c.AUMHashes["A"+strconv.Itoa(25-ancestorsSkipStart< A2 + // Node 2 has: A1 -> A2 -> A3 -> A4 + c := newTestchain(t, ` + A1 -> A2 -> A3 -> A4 + `) + a1H, a2H := c.AUMHashes["A1"], c.AUMHashes["A2"] + + chonk1 := c.ChonkWith("A1", "A2") + n1, err := Open(chonk1) + if err != nil { + t.Fatal(err) + } + offer1, err := n1.SyncOffer(chonk1) + if err != nil { + t.Fatal(err) + } + + chonk2 := c.Chonk() // All AUMs + n2, err := Open(chonk2) + if err != nil { + t.Fatal(err) + } + offer2, err := n2.SyncOffer(chonk2) + if err != nil { + t.Fatal(err) + } + + // Node 1 only knows about the first two nodes, so the head of n2 is + // alien to it. + t.Run("n1", func(t *testing.T) { + got, err := computeSyncIntersection(chonk1, offer1, offer2) + if err != nil { + t.Fatalf("computeSyncIntersection() failed: %v", err) + } + want := &intersection{ + tailIntersection: &a1H, + } + if diff := cmp.Diff(want, got, cmp.AllowUnexported(intersection{})); diff != "" { + t.Errorf("intersection diff (-want, +got):\n%s", diff) + } + }) + + // Node 2 knows about the full chain, so it can see that the head of n1 + // intersects with a subset of its chain (a Head Intersection). + t.Run("n2", func(t *testing.T) { + got, err := computeSyncIntersection(chonk2, offer2, offer1) + if err != nil { + t.Fatalf("computeSyncIntersection() failed: %v", err) + } + want := &intersection{ + headIntersection: &a2H, + } + if diff := cmp.Diff(want, got, cmp.AllowUnexported(intersection{})); diff != "" { + t.Errorf("intersection diff (-want, +got):\n%s", diff) + } + }) +} + +func TestComputeSyncIntersection_ForkSmallDiff(t *testing.T) { + // The number of nodes in the chain is longer than ancestorSkipStart, + // so that during sync both nodes are able to find a common ancestor + // which was later than A1. + + c := newTestchain(t, ` + A1 -> A2 -> A3 -> A4 -> A5 -> A6 -> A7 -> A8 -> A9 -> A10 + | -> F1 + // Make F1 different to A9. + // hashSeed is chosen such that the hash is higher than A9. + F1.hashSeed = 7 + `) + // Node 1 has: A1 -> A2 -> A3 -> A4 -> A5 -> A6 -> A7 -> A8 -> F1 + // Node 2 has: A1 -> A2 -> A3 -> A4 -> A5 -> A6 -> A7 -> A8 -> A9 -> A10 + f1H, a9H := c.AUMHashes["F1"], c.AUMHashes["A9"] + + if bytes.Compare(f1H[:], a9H[:]) < 0 { + t.Fatal("failed assert: h(a9) > h(f1H)\nTweak hashSeed till this passes") + } + + chonk1 := c.ChonkWith("A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "F1") + n1, err := Open(chonk1) + if err != nil { + t.Fatal(err) + } + offer1, err := n1.SyncOffer(chonk1) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(SyncOffer{ + Head: c.AUMHashes["F1"], + Ancestors: []AUMHash{ + c.AUMHashes["A"+strconv.Itoa(9-ancestorsSkipStart)], + c.AUMHashes["A1"], + }, + }, offer1); diff != "" { + t.Errorf("offer1 diff (-want, +got):\n%s", diff) + } + + chonk2 := c.ChonkWith("A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9", "A10") + n2, err := Open(chonk2) + if err != nil { + t.Fatal(err) + } + offer2, err := n2.SyncOffer(chonk2) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(SyncOffer{ + Head: c.AUMHashes["A10"], + Ancestors: []AUMHash{ + c.AUMHashes["A"+strconv.Itoa(10-ancestorsSkipStart)], + c.AUMHashes["A1"], + }, + }, offer2); diff != "" { + t.Errorf("offer2 diff (-want, +got):\n%s", diff) + } + + // Node 1 only knows about the first eight nodes, so the head of n2 is + // alien to it. + t.Run("n1", func(t *testing.T) { + // n2 has 10 nodes, so the first common ancestor should be 10-ancestorsSkipStart + wantIntersection := c.AUMHashes["A"+strconv.Itoa(10-ancestorsSkipStart)] + + got, err := computeSyncIntersection(chonk1, offer1, offer2) + if err != nil { + t.Fatalf("computeSyncIntersection() failed: %v", err) + } + want := &intersection{ + tailIntersection: &wantIntersection, + } + if diff := cmp.Diff(want, got, cmp.AllowUnexported(intersection{})); diff != "" { + t.Errorf("intersection diff (-want, +got):\n%s", diff) + } + }) + + // Node 2 knows about the full chain but doesn't recognize the head. + t.Run("n2", func(t *testing.T) { + // n1 has 9 nodes, so the first common ancestor should be 9-ancestorsSkipStart + wantIntersection := c.AUMHashes["A"+strconv.Itoa(9-ancestorsSkipStart)] + + got, err := computeSyncIntersection(chonk2, offer2, offer1) + if err != nil { + t.Fatalf("computeSyncIntersection() failed: %v", err) + } + want := &intersection{ + tailIntersection: &wantIntersection, + } + if diff := cmp.Diff(want, got, cmp.AllowUnexported(intersection{})); diff != "" { + t.Errorf("intersection diff (-want, +got):\n%s", diff) + } + }) +} + +func TestMissingAUMs_FastForward(t *testing.T) { + // Node 1 has: A1 -> A2 + // Node 2 has: A1 -> A2 -> A3 -> A4 + c := newTestchain(t, ` + A1 -> A2 -> A3 -> A4 + A1.hashSeed = 1 + A2.hashSeed = 2 + A3.hashSeed = 3 + A4.hashSeed = 4 + `) + + chonk1 := c.ChonkWith("A1", "A2") + n1, err := Open(chonk1) + if err != nil { + t.Fatal(err) + } + offer1, err := n1.SyncOffer(chonk1) + if err != nil { + t.Fatal(err) + } + + chonk2 := c.Chonk() // All AUMs + n2, err := Open(chonk2) + if err != nil { + t.Fatal(err) + } + offer2, err := n2.SyncOffer(chonk2) + if err != nil { + t.Fatal(err) + } + + // Node 1 only knows about the first two nodes, so the head of n2 is + // alien to it. As such, it should send history from the newest ancestor, + // A1 (if the chain was longer there would be one in the middle). + t.Run("n1", func(t *testing.T) { + got, err := n1.MissingAUMs(chonk1, offer2) + if err != nil { + t.Fatalf("MissingAUMs() failed: %v", err) + } + + // Both sides have A1, so the only AUM that n2 might not have is + // A2. + want := []AUM{c.AUMs["A2"]} + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("MissingAUMs diff (-want, +got):\n%s", diff) + } + }) + + // Node 2 knows about the full chain, so it can see that the head of n1 + // intersects with a subset of its chain (a Head Intersection). + t.Run("n2", func(t *testing.T) { + got, err := n2.MissingAUMs(chonk2, offer1) + if err != nil { + t.Fatalf("MissingAUMs() failed: %v", err) + } + + want := []AUM{ + c.AUMs["A3"], + c.AUMs["A4"], + } + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("MissingAUMs diff (-want, +got):\n%s", diff) + } + }) +} + +func TestMissingAUMs_Fork(t *testing.T) { + // Node 1 has: A1 -> A2 -> A3 -> F1 + // Node 2 has: A1 -> A2 -> A3 -> A4 + c := newTestchain(t, ` + A1 -> A2 -> A3 -> A4 + | -> F1 + A1.hashSeed = 1 + A2.hashSeed = 2 + A3.hashSeed = 3 + A4.hashSeed = 4 + `) + + chonk1 := c.ChonkWith("A1", "A2", "A3", "F1") + n1, err := Open(chonk1) + if err != nil { + t.Fatal(err) + } + offer1, err := n1.SyncOffer(chonk1) + if err != nil { + t.Fatal(err) + } + + chonk2 := c.ChonkWith("A1", "A2", "A3", "A4") + n2, err := Open(chonk2) + if err != nil { + t.Fatal(err) + } + offer2, err := n2.SyncOffer(chonk2) + if err != nil { + t.Fatal(err) + } + + t.Run("n1", func(t *testing.T) { + got, err := n1.MissingAUMs(chonk1, offer2) + if err != nil { + t.Fatalf("MissingAUMs() failed: %v", err) + } + + // Both sides have A1, so n1 will send everything it knows from + // there to head. + want := []AUM{ + c.AUMs["A2"], + c.AUMs["A3"], + c.AUMs["F1"], + } + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("MissingAUMs diff (-want, +got):\n%s", diff) + } + }) + + t.Run("n2", func(t *testing.T) { + got, err := n2.MissingAUMs(chonk2, offer1) + if err != nil { + t.Fatalf("MissingAUMs() failed: %v", err) + } + + // Both sides have A1, so n2 will send everything it knows from + // there to head. + want := []AUM{ + c.AUMs["A2"], + c.AUMs["A3"], + c.AUMs["A4"], + } + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("MissingAUMs diff (-want, +got):\n%s", diff) + } + }) +} + +func TestSyncSimpleE2E(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + c := newTestchain(t, ` + G1 -> L1 -> L2 -> L3 + G1.template = genesis + `, + optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }}), + optKey("key", key, priv), + optSignAllUsing("key")) + + nodeStorage := &Mem{} + node, err := Bootstrap(nodeStorage, c.AUMs["G1"]) + if err != nil { + t.Fatalf("node Bootstrap() failed: %v", err) + } + controlStorage := c.Chonk() + control, err := Open(controlStorage) + if err != nil { + t.Fatalf("control Open() failed: %v", err) + } + + // Control knows the full chain, node only knows the genesis. Lets see + // if they can sync. + nodeOffer, err := node.SyncOffer(nodeStorage) + if err != nil { + t.Fatal(err) + } + controlAUMs, err := control.MissingAUMs(controlStorage, nodeOffer) + if err != nil { + t.Fatalf("control.MissingAUMs(%v) failed: %v", nodeOffer, err) + } + if err := node.Inform(nodeStorage, controlAUMs); err != nil { + t.Fatalf("node.Inform(%v) failed: %v", controlAUMs, err) + } + + if cHash, nHash := control.Head(), node.Head(); cHash != nHash { + t.Errorf("node & control are not synced: c=%x, n=%x", cHash, nHash) + } +} diff --git a/tka/tailchonk_test.go b/tka/tailchonk_test.go index 13d989f0c..86d5642a3 100644 --- a/tka/tailchonk_test.go +++ b/tka/tailchonk_test.go @@ -1,693 +1,693 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "bytes" - "fmt" - "os" - "path/filepath" - "sync" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "golang.org/x/crypto/blake2s" -) - -// randHash derives a fake blake2s hash from the test name -// and the given seed. -func randHash(t *testing.T, seed int64) [blake2s.Size]byte { - var out [blake2s.Size]byte - testingRand(t, seed).Read(out[:]) - return out -} - -func TestImplementsChonk(t *testing.T) { - impls := []Chonk{&Mem{}, &FS{}} - t.Logf("chonks: %v", impls) -} - -func TestTailchonk_ChildAUMs(t *testing.T) { - for _, chonk := range []Chonk{&Mem{}, &FS{base: t.TempDir()}} { - t.Run(fmt.Sprintf("%T", chonk), func(t *testing.T) { - parentHash := randHash(t, 1) - data := []AUM{ - { - MessageKind: AUMRemoveKey, - KeyID: []byte{1, 2}, - PrevAUMHash: parentHash[:], - }, - { - MessageKind: AUMRemoveKey, - KeyID: []byte{3, 4}, - PrevAUMHash: parentHash[:], - }, - } - - if err := chonk.CommitVerifiedAUMs(data); err != nil { - t.Fatalf("CommitVerifiedAUMs failed: %v", err) - } - stored, err := chonk.ChildAUMs(parentHash) - if err != nil { - t.Fatalf("ChildAUMs failed: %v", err) - } - if diff := cmp.Diff(data, stored); diff != "" { - t.Errorf("stored AUM differs (-want, +got):\n%s", diff) - } - }) - } -} - -func TestTailchonk_AUMMissing(t *testing.T) { - for _, chonk := range []Chonk{&Mem{}, &FS{base: t.TempDir()}} { - t.Run(fmt.Sprintf("%T", chonk), func(t *testing.T) { - var notExists AUMHash - notExists[:][0] = 42 - if _, err := chonk.AUM(notExists); err != os.ErrNotExist { - t.Errorf("chonk.AUM(notExists).err = %v, want %v", err, os.ErrNotExist) - } - }) - } -} - -func TestTailchonkMem_Orphans(t *testing.T) { - chonk := Mem{} - - parentHash := randHash(t, 1) - orphan := AUM{MessageKind: AUMNoOp} - aums := []AUM{ - orphan, - // A parent is specified, so we shouldnt see it in GetOrphans() - { - MessageKind: AUMRemoveKey, - KeyID: []byte{3, 4}, - PrevAUMHash: parentHash[:], - }, - } - if err := chonk.CommitVerifiedAUMs(aums); err != nil { - t.Fatalf("CommitVerifiedAUMs failed: %v", err) - } - - stored, err := chonk.Orphans() - if err != nil { - t.Fatalf("Orphans failed: %v", err) - } - if diff := cmp.Diff([]AUM{orphan}, stored); diff != "" { - t.Errorf("stored AUM differs (-want, +got):\n%s", diff) - } -} - -func TestTailchonk_ReadChainFromHead(t *testing.T) { - for _, chonk := range []Chonk{&Mem{}, &FS{base: t.TempDir()}} { - - t.Run(fmt.Sprintf("%T", chonk), func(t *testing.T) { - genesis := AUM{MessageKind: AUMRemoveKey, KeyID: []byte{1, 2}} - gHash := genesis.Hash() - intermediate := AUM{PrevAUMHash: gHash[:]} - iHash := intermediate.Hash() - leaf := AUM{PrevAUMHash: iHash[:]} - - commitSet := []AUM{ - genesis, - intermediate, - leaf, - } - if err := chonk.CommitVerifiedAUMs(commitSet); err != nil { - t.Fatalf("CommitVerifiedAUMs failed: %v", err) - } - // t.Logf("genesis hash = %X", genesis.Hash()) - // t.Logf("intermediate hash = %X", intermediate.Hash()) - // t.Logf("leaf hash = %X", leaf.Hash()) - - // Read the chain from the leaf backwards. - gotLeafs, err := chonk.Heads() - if err != nil { - t.Fatalf("Heads failed: %v", err) - } - if diff := cmp.Diff([]AUM{leaf}, gotLeafs); diff != "" { - t.Fatalf("leaf AUM differs (-want, +got):\n%s", diff) - } - - parent, _ := gotLeafs[0].Parent() - gotIntermediate, err := chonk.AUM(parent) - if err != nil { - t.Fatalf("AUM() failed: %v", err) - } - if diff := cmp.Diff(intermediate, gotIntermediate); diff != "" { - t.Errorf("intermediate AUM differs (-want, +got):\n%s", diff) - } - - parent, _ = gotIntermediate.Parent() - gotGenesis, err := chonk.AUM(parent) - if err != nil { - t.Fatalf("AUM() failed: %v", err) - } - if diff := cmp.Diff(genesis, gotGenesis); diff != "" { - t.Errorf("genesis AUM differs (-want, +got):\n%s", diff) - } - }) - } -} - -func TestTailchonkFS_Commit(t *testing.T) { - chonk := &FS{base: t.TempDir()} - parentHash := randHash(t, 1) - aum := AUM{MessageKind: AUMNoOp, PrevAUMHash: parentHash[:]} - - if err := chonk.CommitVerifiedAUMs([]AUM{aum}); err != nil { - t.Fatal(err) - } - - dir, base := chonk.aumDir(aum.Hash()) - if got, want := dir, filepath.Join(chonk.base, "PD"); got != want { - t.Errorf("aum dir=%s, want %s", got, want) - } - if want := "PD57DVP6GKC76OOZMXFFZUSOEFQXOLAVT7N2ZM5KB3HDIMCANF4A"; base != want { - t.Errorf("aum base=%s, want %s", base, want) - } - if _, err := os.Stat(filepath.Join(dir, base)); err != nil { - t.Errorf("stat of AUM file failed: %v", err) - } - if _, err := os.Stat(filepath.Join(chonk.base, "M7", "M7LL2NDB4NKCZIUPVS6RDM2GUOIMW6EEAFVBWMVCPUANQJPHT3SQ")); err != nil { - t.Errorf("stat of AUM parent failed: %v", err) - } - - info, err := chonk.get(aum.Hash()) - if err != nil { - t.Fatal(err) - } - if info.PurgedUnix > 0 { - t.Errorf("recently-created AUM PurgedUnix = %d, want 0", info.PurgedUnix) - } -} - -func TestTailchonkFS_CommitTime(t *testing.T) { - chonk := &FS{base: t.TempDir()} - parentHash := randHash(t, 1) - aum := AUM{MessageKind: AUMNoOp, PrevAUMHash: parentHash[:]} - - if err := chonk.CommitVerifiedAUMs([]AUM{aum}); err != nil { - t.Fatal(err) - } - ct, err := chonk.CommitTime(aum.Hash()) - if err != nil { - t.Fatalf("CommitTime() failed: %v", err) - } - if ct.Before(time.Now().Add(-time.Minute)) || ct.After(time.Now().Add(time.Minute)) { - t.Errorf("commit time was wrong: %v more than a minute off from now (%v)", ct, time.Now()) - } -} - -func TestTailchonkFS_PurgeAUMs(t *testing.T) { - chonk := &FS{base: t.TempDir()} - parentHash := randHash(t, 1) - aum := AUM{MessageKind: AUMNoOp, PrevAUMHash: parentHash[:]} - - if err := chonk.CommitVerifiedAUMs([]AUM{aum}); err != nil { - t.Fatal(err) - } - if err := chonk.PurgeAUMs([]AUMHash{aum.Hash()}); err != nil { - t.Fatal(err) - } - - if _, err := chonk.AUM(aum.Hash()); err != os.ErrNotExist { - t.Errorf("AUM() on purged AUM returned err = %v, want ErrNotExist", err) - } - - info, err := chonk.get(aum.Hash()) - if err != nil { - t.Fatal(err) - } - if info.PurgedUnix == 0 { - t.Errorf("recently-created AUM PurgedUnix = %d, want non-zero", info.PurgedUnix) - } -} - -func TestTailchonkFS_AllAUMs(t *testing.T) { - chonk := &FS{base: t.TempDir()} - genesis := AUM{MessageKind: AUMRemoveKey, KeyID: []byte{1, 2}} - gHash := genesis.Hash() - intermediate := AUM{PrevAUMHash: gHash[:]} - iHash := intermediate.Hash() - leaf := AUM{PrevAUMHash: iHash[:]} - - commitSet := []AUM{ - genesis, - intermediate, - leaf, - } - if err := chonk.CommitVerifiedAUMs(commitSet); err != nil { - t.Fatalf("CommitVerifiedAUMs failed: %v", err) - } - - hashes, err := chonk.AllAUMs() - if err != nil { - t.Fatal(err) - } - hashesLess := func(a, b AUMHash) bool { - return bytes.Compare(a[:], b[:]) < 0 - } - if diff := cmp.Diff([]AUMHash{genesis.Hash(), intermediate.Hash(), leaf.Hash()}, hashes, cmpopts.SortSlices(hashesLess)); diff != "" { - t.Fatalf("AllAUMs() output differs (-want, +got):\n%s", diff) - } -} - -func TestMarkActiveChain(t *testing.T) { - type aumTemplate struct { - AUM AUM - } - - tcs := []struct { - name string - minChain int - chain []aumTemplate - expectLastActiveIdx int // expected lastActiveAncestor, corresponds to an index on chain. - }{ - { - name: "genesis", - minChain: 2, - chain: []aumTemplate{ - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - }, - expectLastActiveIdx: 0, - }, - { - name: "simple truncate", - minChain: 2, - chain: []aumTemplate{ - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - }, - expectLastActiveIdx: 1, - }, - { - name: "long truncate", - minChain: 5, - chain: []aumTemplate{ - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - }, - expectLastActiveIdx: 2, - }, - { - name: "truncate finding checkpoint", - minChain: 2, - chain: []aumTemplate{ - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMAddKey, Key: &Key{}}}, // Should keep searching upwards for a checkpoint - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, - }, - expectLastActiveIdx: 1, - }, - } - - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { - verdict := make(map[AUMHash]retainState, len(tc.chain)) - - // Build the state of the tailchonk for tests. - storage := &Mem{} - var prev AUMHash - for i := range tc.chain { - if !prev.IsZero() { - tc.chain[i].AUM.PrevAUMHash = make([]byte, len(prev[:])) - copy(tc.chain[i].AUM.PrevAUMHash, prev[:]) - } - if err := storage.CommitVerifiedAUMs([]AUM{tc.chain[i].AUM}); err != nil { - t.Fatal(err) - } - - h := tc.chain[i].AUM.Hash() - prev = h - verdict[h] = 0 - } - - got, err := markActiveChain(storage, verdict, tc.minChain, prev) - if err != nil { - t.Logf("state = %+v", verdict) - t.Fatalf("markActiveChain() failed: %v", err) - } - want := tc.chain[tc.expectLastActiveIdx].AUM.Hash() - if got != want { - t.Logf("state = %+v", verdict) - t.Errorf("lastActiveAncestor = %v, want %v", got, want) - } - - // Make sure the verdict array was marked correctly. - for i := range tc.chain { - h := tc.chain[i].AUM.Hash() - if i >= tc.expectLastActiveIdx { - if (verdict[h] & retainStateActive) == 0 { - t.Errorf("verdict[%v] = %v, want %v set", h, verdict[h], retainStateActive) - } - } else { - if (verdict[h] & retainStateCandidate) == 0 { - t.Errorf("verdict[%v] = %v, want %v set", h, verdict[h], retainStateCandidate) - } - } - } - }) - } -} - -func TestMarkDescendantAUMs(t *testing.T) { - c := newTestchain(t, ` - genesis -> B -> C -> C2 - | -> D - | -> E -> F -> G -> H - | -> E2 - - // tweak seeds so hashes arent identical - C.hashSeed = 1 - D.hashSeed = 2 - E.hashSeed = 3 - E2.hashSeed = 4 - `) - - verdict := make(map[AUMHash]retainState, len(c.AUMs)) - for _, a := range c.AUMs { - verdict[a.Hash()] = 0 - } - - // Mark E & C. - verdict[c.AUMHashes["C"]] = retainStateActive - verdict[c.AUMHashes["E"]] = retainStateActive - - if err := markDescendantAUMs(c.Chonk(), verdict); err != nil { - t.Errorf("markDescendantAUMs() failed: %v", err) - } - - // Make sure the descendants got marked. - hs := c.AUMHashes - for _, h := range []AUMHash{hs["C2"], hs["F"], hs["G"], hs["H"], hs["E2"]} { - if (verdict[h] & retainStateLeaf) == 0 { - t.Errorf("%v was not marked as a descendant", h) - } - } - for _, h := range []AUMHash{hs["genesis"], hs["B"], hs["D"]} { - if (verdict[h] & retainStateLeaf) != 0 { - t.Errorf("%v was marked as a descendant and shouldnt be", h) - } - } -} - -func TestMarkAncestorIntersectionAUMs(t *testing.T) { - fakeState := &State{ - Keys: []Key{{Kind: Key25519, Votes: 1}}, - DisablementSecrets: [][]byte{bytes.Repeat([]byte{1}, 32)}, - } - - tcs := []struct { - name string - chain *testChain - verdicts map[string]retainState - initialAncestor string - wantAncestor string - wantRetained []string - wantDeleted []string - }{ - { - name: "genesis", - chain: newTestchain(t, ` - A - A.template = checkpoint`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), - initialAncestor: "A", - wantAncestor: "A", - verdicts: map[string]retainState{ - "A": retainStateActive, - }, - wantRetained: []string{"A"}, - }, - { - name: "no adjustment", - chain: newTestchain(t, ` - DEAD -> A -> B -> C - A.template = checkpoint - B.template = checkpoint`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), - initialAncestor: "A", - wantAncestor: "A", - verdicts: map[string]retainState{ - "A": retainStateActive, - "B": retainStateActive, - "C": retainStateActive, - "DEAD": retainStateCandidate, - }, - wantRetained: []string{"A", "B", "C"}, - wantDeleted: []string{"DEAD"}, - }, - { - name: "fork", - chain: newTestchain(t, ` - A -> B -> C -> D - | -> FORK - A.template = checkpoint - C.template = checkpoint - D.template = checkpoint - FORK.hashSeed = 2`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), - initialAncestor: "D", - wantAncestor: "C", - verdicts: map[string]retainState{ - "A": retainStateCandidate, - "B": retainStateCandidate, - "C": retainStateCandidate, - "D": retainStateActive, - "FORK": retainStateYoung, - }, - wantRetained: []string{"C", "D", "FORK"}, - wantDeleted: []string{"A", "B"}, - }, - { - name: "fork finding earlier checkpoint", - chain: newTestchain(t, ` - A -> B -> C -> D -> E -> F - | -> FORK - A.template = checkpoint - B.template = checkpoint - E.template = checkpoint - FORK.hashSeed = 2`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), - initialAncestor: "E", - wantAncestor: "B", - verdicts: map[string]retainState{ - "A": retainStateCandidate, - "B": retainStateCandidate, - "C": retainStateCandidate, - "D": retainStateCandidate, - "E": retainStateActive, - "F": retainStateActive, - "FORK": retainStateYoung, - }, - wantRetained: []string{"B", "C", "D", "E", "F", "FORK"}, - wantDeleted: []string{"A"}, - }, - { - name: "fork multi", - chain: newTestchain(t, ` - A -> B -> C -> D -> E - | -> DEADFORK - C -> FORK - A.template = checkpoint - C.template = checkpoint - D.template = checkpoint - E.template = checkpoint - FORK.hashSeed = 2 - DEADFORK.hashSeed = 3`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), - initialAncestor: "D", - wantAncestor: "C", - verdicts: map[string]retainState{ - "A": retainStateCandidate, - "B": retainStateCandidate, - "C": retainStateCandidate, - "D": retainStateActive, - "E": retainStateActive, - "FORK": retainStateYoung, - "DEADFORK": 0, - }, - wantRetained: []string{"C", "D", "E", "FORK"}, - wantDeleted: []string{"A", "B", "DEADFORK"}, - }, - { - name: "fork multi 2", - chain: newTestchain(t, ` - A -> B -> C -> D -> E -> F -> G - - F -> F1 - D -> F2 - B -> F3 - - A.template = checkpoint - B.template = checkpoint - D.template = checkpoint - F.template = checkpoint - F1.hashSeed = 2 - F2.hashSeed = 3 - F3.hashSeed = 4`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), - initialAncestor: "F", - wantAncestor: "B", - verdicts: map[string]retainState{ - "A": retainStateCandidate, - "B": retainStateCandidate, - "C": retainStateCandidate, - "D": retainStateCandidate, - "E": retainStateCandidate, - "F": retainStateActive, - "G": retainStateActive, - "F1": retainStateYoung, - "F2": retainStateYoung, - "F3": retainStateYoung, - }, - wantRetained: []string{"B", "C", "D", "E", "F", "G", "F1", "F2", "F3"}, - }, - } - - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { - verdict := make(map[AUMHash]retainState, len(tc.verdicts)) - for name, v := range tc.verdicts { - verdict[tc.chain.AUMHashes[name]] = v - } - - got, err := markAncestorIntersectionAUMs(tc.chain.Chonk(), verdict, tc.chain.AUMHashes[tc.initialAncestor]) - if err != nil { - t.Logf("state = %+v", verdict) - t.Fatalf("markAncestorIntersectionAUMs() failed: %v", err) - } - if want := tc.chain.AUMHashes[tc.wantAncestor]; got != want { - t.Logf("state = %+v", verdict) - t.Errorf("lastActiveAncestor = %v, want %v", got, want) - } - - for _, name := range tc.wantRetained { - h := tc.chain.AUMHashes[name] - if v := verdict[h]; v&retainAUMMask == 0 { - t.Errorf("AUM %q was not retained: verdict = %v", name, v) - } - } - for _, name := range tc.wantDeleted { - h := tc.chain.AUMHashes[name] - if v := verdict[h]; v&retainAUMMask != 0 { - t.Errorf("AUM %q was retained: verdict = %v", name, v) - } - } - - if t.Failed() { - for name, hash := range tc.chain.AUMHashes { - t.Logf("AUM[%q] = %v", name, hash) - } - } - }) - } -} - -type compactingChonkFake struct { - Mem - - aumAge map[AUMHash]time.Time - t *testing.T - wantDelete []AUMHash -} - -func (c *compactingChonkFake) AllAUMs() ([]AUMHash, error) { - out := make([]AUMHash, 0, len(c.Mem.aums)) - for h := range c.Mem.aums { - out = append(out, h) - } - return out, nil -} - -func (c *compactingChonkFake) CommitTime(hash AUMHash) (time.Time, error) { - return c.aumAge[hash], nil -} - -func (c *compactingChonkFake) PurgeAUMs(hashes []AUMHash) error { - lessHashes := func(a, b AUMHash) bool { - return bytes.Compare(a[:], b[:]) < 0 - } - if diff := cmp.Diff(c.wantDelete, hashes, cmpopts.SortSlices(lessHashes)); diff != "" { - c.t.Errorf("deletion set differs (-want, +got):\n%s", diff) - } - return nil -} - -// Avoid go vet complaining about copying a lock value -func cloneMem(src, dst *Mem) { - dst.l = sync.RWMutex{} - dst.aums = src.aums - dst.parentIndex = src.parentIndex - dst.lastActiveAncestor = src.lastActiveAncestor -} - -func TestCompact(t *testing.T) { - fakeState := &State{ - Keys: []Key{{Kind: Key25519, Votes: 1}}, - DisablementSecrets: [][]byte{bytes.Repeat([]byte{1}, 32)}, - } - - // A & B are deleted because the new lastActiveAncestor advances beyond them. - // OLD is deleted because it does not match retention criteria, and - // though it is a descendant of the new lastActiveAncestor (C), it is not a - // descendant of a retained AUM. - // G, & H are retained as recent (MinChain=2) ancestors of HEAD. - // E & F are retained because they are between retained AUMs (G+) and - // their newest checkpoint ancestor. - // D is retained because it is the newest checkpoint ancestor from - // MinChain-retained AUMs. - // G2 is retained because it is a descendant of a retained AUM (G). - // F1 is retained because it is new enough by wall-clock time. - // F2 is retained because it is a descendant of a retained AUM (F1). - // C2 is retained because it is between an ancestor checkpoint and - // a retained AUM (F1). - // C is retained because it is the new lastActiveAncestor. It is the - // new lastActiveAncestor because it is the newest common checkpoint - // of all retained AUMs. - c := newTestchain(t, ` - A -> B -> C -> C2 -> D -> E -> F -> G -> H - | -> F1 -> F2 | -> G2 - | -> OLD - - // make {A,B,C,D} compaction candidates - A.template = checkpoint - B.template = checkpoint - C.template = checkpoint - D.template = checkpoint - - // tweak seeds of forks so hashes arent identical - F1.hashSeed = 1 - OLD.hashSeed = 2 - G2.hashSeed = 3 - `, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})) - - storage := &compactingChonkFake{ - aumAge: map[AUMHash]time.Time{(c.AUMHashes["F1"]): time.Now()}, - t: t, - wantDelete: []AUMHash{c.AUMHashes["A"], c.AUMHashes["B"], c.AUMHashes["OLD"]}, - } - - cloneMem(c.Chonk().(*Mem), &storage.Mem) - - lastActiveAncestor, err := Compact(storage, c.AUMHashes["H"], CompactionOptions{MinChain: 2, MinAge: time.Hour}) - if err != nil { - t.Errorf("Compact() failed: %v", err) - } - if lastActiveAncestor != c.AUMHashes["C"] { - t.Errorf("last active ancestor = %v, want %v", lastActiveAncestor, c.AUMHashes["C"]) - } - - if t.Failed() { - for name, hash := range c.AUMHashes { - t.Logf("AUM[%q] = %v", name, hash) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "bytes" + "fmt" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "golang.org/x/crypto/blake2s" +) + +// randHash derives a fake blake2s hash from the test name +// and the given seed. +func randHash(t *testing.T, seed int64) [blake2s.Size]byte { + var out [blake2s.Size]byte + testingRand(t, seed).Read(out[:]) + return out +} + +func TestImplementsChonk(t *testing.T) { + impls := []Chonk{&Mem{}, &FS{}} + t.Logf("chonks: %v", impls) +} + +func TestTailchonk_ChildAUMs(t *testing.T) { + for _, chonk := range []Chonk{&Mem{}, &FS{base: t.TempDir()}} { + t.Run(fmt.Sprintf("%T", chonk), func(t *testing.T) { + parentHash := randHash(t, 1) + data := []AUM{ + { + MessageKind: AUMRemoveKey, + KeyID: []byte{1, 2}, + PrevAUMHash: parentHash[:], + }, + { + MessageKind: AUMRemoveKey, + KeyID: []byte{3, 4}, + PrevAUMHash: parentHash[:], + }, + } + + if err := chonk.CommitVerifiedAUMs(data); err != nil { + t.Fatalf("CommitVerifiedAUMs failed: %v", err) + } + stored, err := chonk.ChildAUMs(parentHash) + if err != nil { + t.Fatalf("ChildAUMs failed: %v", err) + } + if diff := cmp.Diff(data, stored); diff != "" { + t.Errorf("stored AUM differs (-want, +got):\n%s", diff) + } + }) + } +} + +func TestTailchonk_AUMMissing(t *testing.T) { + for _, chonk := range []Chonk{&Mem{}, &FS{base: t.TempDir()}} { + t.Run(fmt.Sprintf("%T", chonk), func(t *testing.T) { + var notExists AUMHash + notExists[:][0] = 42 + if _, err := chonk.AUM(notExists); err != os.ErrNotExist { + t.Errorf("chonk.AUM(notExists).err = %v, want %v", err, os.ErrNotExist) + } + }) + } +} + +func TestTailchonkMem_Orphans(t *testing.T) { + chonk := Mem{} + + parentHash := randHash(t, 1) + orphan := AUM{MessageKind: AUMNoOp} + aums := []AUM{ + orphan, + // A parent is specified, so we shouldnt see it in GetOrphans() + { + MessageKind: AUMRemoveKey, + KeyID: []byte{3, 4}, + PrevAUMHash: parentHash[:], + }, + } + if err := chonk.CommitVerifiedAUMs(aums); err != nil { + t.Fatalf("CommitVerifiedAUMs failed: %v", err) + } + + stored, err := chonk.Orphans() + if err != nil { + t.Fatalf("Orphans failed: %v", err) + } + if diff := cmp.Diff([]AUM{orphan}, stored); diff != "" { + t.Errorf("stored AUM differs (-want, +got):\n%s", diff) + } +} + +func TestTailchonk_ReadChainFromHead(t *testing.T) { + for _, chonk := range []Chonk{&Mem{}, &FS{base: t.TempDir()}} { + + t.Run(fmt.Sprintf("%T", chonk), func(t *testing.T) { + genesis := AUM{MessageKind: AUMRemoveKey, KeyID: []byte{1, 2}} + gHash := genesis.Hash() + intermediate := AUM{PrevAUMHash: gHash[:]} + iHash := intermediate.Hash() + leaf := AUM{PrevAUMHash: iHash[:]} + + commitSet := []AUM{ + genesis, + intermediate, + leaf, + } + if err := chonk.CommitVerifiedAUMs(commitSet); err != nil { + t.Fatalf("CommitVerifiedAUMs failed: %v", err) + } + // t.Logf("genesis hash = %X", genesis.Hash()) + // t.Logf("intermediate hash = %X", intermediate.Hash()) + // t.Logf("leaf hash = %X", leaf.Hash()) + + // Read the chain from the leaf backwards. + gotLeafs, err := chonk.Heads() + if err != nil { + t.Fatalf("Heads failed: %v", err) + } + if diff := cmp.Diff([]AUM{leaf}, gotLeafs); diff != "" { + t.Fatalf("leaf AUM differs (-want, +got):\n%s", diff) + } + + parent, _ := gotLeafs[0].Parent() + gotIntermediate, err := chonk.AUM(parent) + if err != nil { + t.Fatalf("AUM() failed: %v", err) + } + if diff := cmp.Diff(intermediate, gotIntermediate); diff != "" { + t.Errorf("intermediate AUM differs (-want, +got):\n%s", diff) + } + + parent, _ = gotIntermediate.Parent() + gotGenesis, err := chonk.AUM(parent) + if err != nil { + t.Fatalf("AUM() failed: %v", err) + } + if diff := cmp.Diff(genesis, gotGenesis); diff != "" { + t.Errorf("genesis AUM differs (-want, +got):\n%s", diff) + } + }) + } +} + +func TestTailchonkFS_Commit(t *testing.T) { + chonk := &FS{base: t.TempDir()} + parentHash := randHash(t, 1) + aum := AUM{MessageKind: AUMNoOp, PrevAUMHash: parentHash[:]} + + if err := chonk.CommitVerifiedAUMs([]AUM{aum}); err != nil { + t.Fatal(err) + } + + dir, base := chonk.aumDir(aum.Hash()) + if got, want := dir, filepath.Join(chonk.base, "PD"); got != want { + t.Errorf("aum dir=%s, want %s", got, want) + } + if want := "PD57DVP6GKC76OOZMXFFZUSOEFQXOLAVT7N2ZM5KB3HDIMCANF4A"; base != want { + t.Errorf("aum base=%s, want %s", base, want) + } + if _, err := os.Stat(filepath.Join(dir, base)); err != nil { + t.Errorf("stat of AUM file failed: %v", err) + } + if _, err := os.Stat(filepath.Join(chonk.base, "M7", "M7LL2NDB4NKCZIUPVS6RDM2GUOIMW6EEAFVBWMVCPUANQJPHT3SQ")); err != nil { + t.Errorf("stat of AUM parent failed: %v", err) + } + + info, err := chonk.get(aum.Hash()) + if err != nil { + t.Fatal(err) + } + if info.PurgedUnix > 0 { + t.Errorf("recently-created AUM PurgedUnix = %d, want 0", info.PurgedUnix) + } +} + +func TestTailchonkFS_CommitTime(t *testing.T) { + chonk := &FS{base: t.TempDir()} + parentHash := randHash(t, 1) + aum := AUM{MessageKind: AUMNoOp, PrevAUMHash: parentHash[:]} + + if err := chonk.CommitVerifiedAUMs([]AUM{aum}); err != nil { + t.Fatal(err) + } + ct, err := chonk.CommitTime(aum.Hash()) + if err != nil { + t.Fatalf("CommitTime() failed: %v", err) + } + if ct.Before(time.Now().Add(-time.Minute)) || ct.After(time.Now().Add(time.Minute)) { + t.Errorf("commit time was wrong: %v more than a minute off from now (%v)", ct, time.Now()) + } +} + +func TestTailchonkFS_PurgeAUMs(t *testing.T) { + chonk := &FS{base: t.TempDir()} + parentHash := randHash(t, 1) + aum := AUM{MessageKind: AUMNoOp, PrevAUMHash: parentHash[:]} + + if err := chonk.CommitVerifiedAUMs([]AUM{aum}); err != nil { + t.Fatal(err) + } + if err := chonk.PurgeAUMs([]AUMHash{aum.Hash()}); err != nil { + t.Fatal(err) + } + + if _, err := chonk.AUM(aum.Hash()); err != os.ErrNotExist { + t.Errorf("AUM() on purged AUM returned err = %v, want ErrNotExist", err) + } + + info, err := chonk.get(aum.Hash()) + if err != nil { + t.Fatal(err) + } + if info.PurgedUnix == 0 { + t.Errorf("recently-created AUM PurgedUnix = %d, want non-zero", info.PurgedUnix) + } +} + +func TestTailchonkFS_AllAUMs(t *testing.T) { + chonk := &FS{base: t.TempDir()} + genesis := AUM{MessageKind: AUMRemoveKey, KeyID: []byte{1, 2}} + gHash := genesis.Hash() + intermediate := AUM{PrevAUMHash: gHash[:]} + iHash := intermediate.Hash() + leaf := AUM{PrevAUMHash: iHash[:]} + + commitSet := []AUM{ + genesis, + intermediate, + leaf, + } + if err := chonk.CommitVerifiedAUMs(commitSet); err != nil { + t.Fatalf("CommitVerifiedAUMs failed: %v", err) + } + + hashes, err := chonk.AllAUMs() + if err != nil { + t.Fatal(err) + } + hashesLess := func(a, b AUMHash) bool { + return bytes.Compare(a[:], b[:]) < 0 + } + if diff := cmp.Diff([]AUMHash{genesis.Hash(), intermediate.Hash(), leaf.Hash()}, hashes, cmpopts.SortSlices(hashesLess)); diff != "" { + t.Fatalf("AllAUMs() output differs (-want, +got):\n%s", diff) + } +} + +func TestMarkActiveChain(t *testing.T) { + type aumTemplate struct { + AUM AUM + } + + tcs := []struct { + name string + minChain int + chain []aumTemplate + expectLastActiveIdx int // expected lastActiveAncestor, corresponds to an index on chain. + }{ + { + name: "genesis", + minChain: 2, + chain: []aumTemplate{ + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + }, + expectLastActiveIdx: 0, + }, + { + name: "simple truncate", + minChain: 2, + chain: []aumTemplate{ + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + }, + expectLastActiveIdx: 1, + }, + { + name: "long truncate", + minChain: 5, + chain: []aumTemplate{ + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + }, + expectLastActiveIdx: 2, + }, + { + name: "truncate finding checkpoint", + minChain: 2, + chain: []aumTemplate{ + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMAddKey, Key: &Key{}}}, // Should keep searching upwards for a checkpoint + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, + }, + expectLastActiveIdx: 1, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + verdict := make(map[AUMHash]retainState, len(tc.chain)) + + // Build the state of the tailchonk for tests. + storage := &Mem{} + var prev AUMHash + for i := range tc.chain { + if !prev.IsZero() { + tc.chain[i].AUM.PrevAUMHash = make([]byte, len(prev[:])) + copy(tc.chain[i].AUM.PrevAUMHash, prev[:]) + } + if err := storage.CommitVerifiedAUMs([]AUM{tc.chain[i].AUM}); err != nil { + t.Fatal(err) + } + + h := tc.chain[i].AUM.Hash() + prev = h + verdict[h] = 0 + } + + got, err := markActiveChain(storage, verdict, tc.minChain, prev) + if err != nil { + t.Logf("state = %+v", verdict) + t.Fatalf("markActiveChain() failed: %v", err) + } + want := tc.chain[tc.expectLastActiveIdx].AUM.Hash() + if got != want { + t.Logf("state = %+v", verdict) + t.Errorf("lastActiveAncestor = %v, want %v", got, want) + } + + // Make sure the verdict array was marked correctly. + for i := range tc.chain { + h := tc.chain[i].AUM.Hash() + if i >= tc.expectLastActiveIdx { + if (verdict[h] & retainStateActive) == 0 { + t.Errorf("verdict[%v] = %v, want %v set", h, verdict[h], retainStateActive) + } + } else { + if (verdict[h] & retainStateCandidate) == 0 { + t.Errorf("verdict[%v] = %v, want %v set", h, verdict[h], retainStateCandidate) + } + } + } + }) + } +} + +func TestMarkDescendantAUMs(t *testing.T) { + c := newTestchain(t, ` + genesis -> B -> C -> C2 + | -> D + | -> E -> F -> G -> H + | -> E2 + + // tweak seeds so hashes arent identical + C.hashSeed = 1 + D.hashSeed = 2 + E.hashSeed = 3 + E2.hashSeed = 4 + `) + + verdict := make(map[AUMHash]retainState, len(c.AUMs)) + for _, a := range c.AUMs { + verdict[a.Hash()] = 0 + } + + // Mark E & C. + verdict[c.AUMHashes["C"]] = retainStateActive + verdict[c.AUMHashes["E"]] = retainStateActive + + if err := markDescendantAUMs(c.Chonk(), verdict); err != nil { + t.Errorf("markDescendantAUMs() failed: %v", err) + } + + // Make sure the descendants got marked. + hs := c.AUMHashes + for _, h := range []AUMHash{hs["C2"], hs["F"], hs["G"], hs["H"], hs["E2"]} { + if (verdict[h] & retainStateLeaf) == 0 { + t.Errorf("%v was not marked as a descendant", h) + } + } + for _, h := range []AUMHash{hs["genesis"], hs["B"], hs["D"]} { + if (verdict[h] & retainStateLeaf) != 0 { + t.Errorf("%v was marked as a descendant and shouldnt be", h) + } + } +} + +func TestMarkAncestorIntersectionAUMs(t *testing.T) { + fakeState := &State{ + Keys: []Key{{Kind: Key25519, Votes: 1}}, + DisablementSecrets: [][]byte{bytes.Repeat([]byte{1}, 32)}, + } + + tcs := []struct { + name string + chain *testChain + verdicts map[string]retainState + initialAncestor string + wantAncestor string + wantRetained []string + wantDeleted []string + }{ + { + name: "genesis", + chain: newTestchain(t, ` + A + A.template = checkpoint`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), + initialAncestor: "A", + wantAncestor: "A", + verdicts: map[string]retainState{ + "A": retainStateActive, + }, + wantRetained: []string{"A"}, + }, + { + name: "no adjustment", + chain: newTestchain(t, ` + DEAD -> A -> B -> C + A.template = checkpoint + B.template = checkpoint`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), + initialAncestor: "A", + wantAncestor: "A", + verdicts: map[string]retainState{ + "A": retainStateActive, + "B": retainStateActive, + "C": retainStateActive, + "DEAD": retainStateCandidate, + }, + wantRetained: []string{"A", "B", "C"}, + wantDeleted: []string{"DEAD"}, + }, + { + name: "fork", + chain: newTestchain(t, ` + A -> B -> C -> D + | -> FORK + A.template = checkpoint + C.template = checkpoint + D.template = checkpoint + FORK.hashSeed = 2`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), + initialAncestor: "D", + wantAncestor: "C", + verdicts: map[string]retainState{ + "A": retainStateCandidate, + "B": retainStateCandidate, + "C": retainStateCandidate, + "D": retainStateActive, + "FORK": retainStateYoung, + }, + wantRetained: []string{"C", "D", "FORK"}, + wantDeleted: []string{"A", "B"}, + }, + { + name: "fork finding earlier checkpoint", + chain: newTestchain(t, ` + A -> B -> C -> D -> E -> F + | -> FORK + A.template = checkpoint + B.template = checkpoint + E.template = checkpoint + FORK.hashSeed = 2`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), + initialAncestor: "E", + wantAncestor: "B", + verdicts: map[string]retainState{ + "A": retainStateCandidate, + "B": retainStateCandidate, + "C": retainStateCandidate, + "D": retainStateCandidate, + "E": retainStateActive, + "F": retainStateActive, + "FORK": retainStateYoung, + }, + wantRetained: []string{"B", "C", "D", "E", "F", "FORK"}, + wantDeleted: []string{"A"}, + }, + { + name: "fork multi", + chain: newTestchain(t, ` + A -> B -> C -> D -> E + | -> DEADFORK + C -> FORK + A.template = checkpoint + C.template = checkpoint + D.template = checkpoint + E.template = checkpoint + FORK.hashSeed = 2 + DEADFORK.hashSeed = 3`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), + initialAncestor: "D", + wantAncestor: "C", + verdicts: map[string]retainState{ + "A": retainStateCandidate, + "B": retainStateCandidate, + "C": retainStateCandidate, + "D": retainStateActive, + "E": retainStateActive, + "FORK": retainStateYoung, + "DEADFORK": 0, + }, + wantRetained: []string{"C", "D", "E", "FORK"}, + wantDeleted: []string{"A", "B", "DEADFORK"}, + }, + { + name: "fork multi 2", + chain: newTestchain(t, ` + A -> B -> C -> D -> E -> F -> G + + F -> F1 + D -> F2 + B -> F3 + + A.template = checkpoint + B.template = checkpoint + D.template = checkpoint + F.template = checkpoint + F1.hashSeed = 2 + F2.hashSeed = 3 + F3.hashSeed = 4`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), + initialAncestor: "F", + wantAncestor: "B", + verdicts: map[string]retainState{ + "A": retainStateCandidate, + "B": retainStateCandidate, + "C": retainStateCandidate, + "D": retainStateCandidate, + "E": retainStateCandidate, + "F": retainStateActive, + "G": retainStateActive, + "F1": retainStateYoung, + "F2": retainStateYoung, + "F3": retainStateYoung, + }, + wantRetained: []string{"B", "C", "D", "E", "F", "G", "F1", "F2", "F3"}, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + verdict := make(map[AUMHash]retainState, len(tc.verdicts)) + for name, v := range tc.verdicts { + verdict[tc.chain.AUMHashes[name]] = v + } + + got, err := markAncestorIntersectionAUMs(tc.chain.Chonk(), verdict, tc.chain.AUMHashes[tc.initialAncestor]) + if err != nil { + t.Logf("state = %+v", verdict) + t.Fatalf("markAncestorIntersectionAUMs() failed: %v", err) + } + if want := tc.chain.AUMHashes[tc.wantAncestor]; got != want { + t.Logf("state = %+v", verdict) + t.Errorf("lastActiveAncestor = %v, want %v", got, want) + } + + for _, name := range tc.wantRetained { + h := tc.chain.AUMHashes[name] + if v := verdict[h]; v&retainAUMMask == 0 { + t.Errorf("AUM %q was not retained: verdict = %v", name, v) + } + } + for _, name := range tc.wantDeleted { + h := tc.chain.AUMHashes[name] + if v := verdict[h]; v&retainAUMMask != 0 { + t.Errorf("AUM %q was retained: verdict = %v", name, v) + } + } + + if t.Failed() { + for name, hash := range tc.chain.AUMHashes { + t.Logf("AUM[%q] = %v", name, hash) + } + } + }) + } +} + +type compactingChonkFake struct { + Mem + + aumAge map[AUMHash]time.Time + t *testing.T + wantDelete []AUMHash +} + +func (c *compactingChonkFake) AllAUMs() ([]AUMHash, error) { + out := make([]AUMHash, 0, len(c.Mem.aums)) + for h := range c.Mem.aums { + out = append(out, h) + } + return out, nil +} + +func (c *compactingChonkFake) CommitTime(hash AUMHash) (time.Time, error) { + return c.aumAge[hash], nil +} + +func (c *compactingChonkFake) PurgeAUMs(hashes []AUMHash) error { + lessHashes := func(a, b AUMHash) bool { + return bytes.Compare(a[:], b[:]) < 0 + } + if diff := cmp.Diff(c.wantDelete, hashes, cmpopts.SortSlices(lessHashes)); diff != "" { + c.t.Errorf("deletion set differs (-want, +got):\n%s", diff) + } + return nil +} + +// Avoid go vet complaining about copying a lock value +func cloneMem(src, dst *Mem) { + dst.l = sync.RWMutex{} + dst.aums = src.aums + dst.parentIndex = src.parentIndex + dst.lastActiveAncestor = src.lastActiveAncestor +} + +func TestCompact(t *testing.T) { + fakeState := &State{ + Keys: []Key{{Kind: Key25519, Votes: 1}}, + DisablementSecrets: [][]byte{bytes.Repeat([]byte{1}, 32)}, + } + + // A & B are deleted because the new lastActiveAncestor advances beyond them. + // OLD is deleted because it does not match retention criteria, and + // though it is a descendant of the new lastActiveAncestor (C), it is not a + // descendant of a retained AUM. + // G, & H are retained as recent (MinChain=2) ancestors of HEAD. + // E & F are retained because they are between retained AUMs (G+) and + // their newest checkpoint ancestor. + // D is retained because it is the newest checkpoint ancestor from + // MinChain-retained AUMs. + // G2 is retained because it is a descendant of a retained AUM (G). + // F1 is retained because it is new enough by wall-clock time. + // F2 is retained because it is a descendant of a retained AUM (F1). + // C2 is retained because it is between an ancestor checkpoint and + // a retained AUM (F1). + // C is retained because it is the new lastActiveAncestor. It is the + // new lastActiveAncestor because it is the newest common checkpoint + // of all retained AUMs. + c := newTestchain(t, ` + A -> B -> C -> C2 -> D -> E -> F -> G -> H + | -> F1 -> F2 | -> G2 + | -> OLD + + // make {A,B,C,D} compaction candidates + A.template = checkpoint + B.template = checkpoint + C.template = checkpoint + D.template = checkpoint + + // tweak seeds of forks so hashes arent identical + F1.hashSeed = 1 + OLD.hashSeed = 2 + G2.hashSeed = 3 + `, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})) + + storage := &compactingChonkFake{ + aumAge: map[AUMHash]time.Time{(c.AUMHashes["F1"]): time.Now()}, + t: t, + wantDelete: []AUMHash{c.AUMHashes["A"], c.AUMHashes["B"], c.AUMHashes["OLD"]}, + } + + cloneMem(c.Chonk().(*Mem), &storage.Mem) + + lastActiveAncestor, err := Compact(storage, c.AUMHashes["H"], CompactionOptions{MinChain: 2, MinAge: time.Hour}) + if err != nil { + t.Errorf("Compact() failed: %v", err) + } + if lastActiveAncestor != c.AUMHashes["C"] { + t.Errorf("last active ancestor = %v, want %v", lastActiveAncestor, c.AUMHashes["C"]) + } + + if t.Failed() { + for name, hash := range c.AUMHashes { + t.Logf("AUM[%q] = %v", name, hash) + } + } +} diff --git a/tka/tka_test.go b/tka/tka_test.go index 3438a4016..9e3c4e79d 100644 --- a/tka/tka_test.go +++ b/tka/tka_test.go @@ -1,654 +1,654 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tka - -import ( - "bytes" - "testing" - - "github.com/google/go-cmp/cmp" - "tailscale.com/types/key" - "tailscale.com/types/tkatype" -) - -func TestComputeChainCandidates(t *testing.T) { - c := newTestchain(t, ` - G1 -> I1 -> I2 -> I3 -> L2 - | -> L1 | -> L3 - - G2 -> L4 - - // We tweak these AUMs so they are different hashes. - G2.hashSeed = 2 - L1.hashSeed = 2 - L3.hashSeed = 2 - L4.hashSeed = 3 - `) - // Should result in 4 chains: - // G1->L1, G1->L2, G1->L3, G2->L4 - - i1H := c.AUMHashes["I1"] - got, err := computeChainCandidates(c.Chonk(), &i1H, 50) - if err != nil { - t.Fatalf("computeChainCandidates() failed: %v", err) - } - - want := []chain{ - {Oldest: c.AUMs["G2"], Head: c.AUMs["L4"]}, - {Oldest: c.AUMs["G1"], Head: c.AUMs["L3"], chainsThroughActive: true}, - {Oldest: c.AUMs["G1"], Head: c.AUMs["L1"], chainsThroughActive: true}, - {Oldest: c.AUMs["G1"], Head: c.AUMs["L2"], chainsThroughActive: true}, - } - if diff := cmp.Diff(want, got, cmp.AllowUnexported(chain{})); diff != "" { - t.Errorf("chains differ (-want, +got):\n%s", diff) - } -} - -func TestForkResolutionHash(t *testing.T) { - c := newTestchain(t, ` - G1 -> L1 - | -> L2 - - // tweak hashes so L1 & L2 are not identical - L1.hashSeed = 2 - L2.hashSeed = 3 - `) - - got, err := computeActiveChain(c.Chonk(), nil, 50) - if err != nil { - t.Fatalf("computeActiveChain() failed: %v", err) - } - - // The fork with the lowest AUM hash should have been chosen. - l1H := c.AUMHashes["L1"] - l2H := c.AUMHashes["L2"] - want := l1H - if bytes.Compare(l2H[:], l1H[:]) < 0 { - want = l2H - } - - if got := got.Head.Hash(); got != want { - t.Errorf("head was %x, want %x", got, want) - } -} - -func TestForkResolutionSigWeight(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - c := newTestchain(t, ` - G1 -> L1 - | -> L2 - - G1.template = addKey - L1.hashSeed = 11 - L2.signedWith = key - `, - optTemplate("addKey", AUM{MessageKind: AUMAddKey, Key: &key}), - optKey("key", key, priv)) - - l1H := c.AUMHashes["L1"] - l2H := c.AUMHashes["L2"] - if bytes.Compare(l2H[:], l1H[:]) < 0 { - t.Fatal("failed assert: h(l1) > h(l2)\nTweak hashSeed till this passes") - } - - got, err := computeActiveChain(c.Chonk(), nil, 50) - if err != nil { - t.Fatalf("computeActiveChain() failed: %v", err) - } - - // Based on the hash, l1H should be chosen. - // But based on the signature weight (which has higher - // precedence), it should be l2H - want := l2H - if got := got.Head.Hash(); got != want { - t.Errorf("head was %x, want %x", got, want) - } -} - -func TestForkResolutionMessageType(t *testing.T) { - pub, _ := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - c := newTestchain(t, ` - G1 -> L1 - | -> L2 - | -> L3 - - G1.template = addKey - L1.hashSeed = 11 - L2.template = removeKey - L3.hashSeed = 18 - `, - optTemplate("addKey", AUM{MessageKind: AUMAddKey, Key: &key}), - optTemplate("removeKey", AUM{MessageKind: AUMRemoveKey, KeyID: key.MustID()})) - - l1H := c.AUMHashes["L1"] - l2H := c.AUMHashes["L2"] - l3H := c.AUMHashes["L3"] - if bytes.Compare(l2H[:], l1H[:]) < 0 { - t.Fatal("failed assert: h(l1) > h(l2)\nTweak hashSeed till this passes") - } - if bytes.Compare(l2H[:], l3H[:]) < 0 { - t.Fatal("failed assert: h(l3) > h(l2)\nTweak hashSeed till this passes") - } - - got, err := computeActiveChain(c.Chonk(), nil, 50) - if err != nil { - t.Fatalf("computeActiveChain() failed: %v", err) - } - - // Based on the hash, L1 or L3 should be chosen. - // But based on the preference for AUMRemoveKey messages, - // it should be L2. - want := l2H - if got := got.Head.Hash(); got != want { - t.Errorf("head was %x, want %x", got, want) - } -} - -func TestComputeStateAt(t *testing.T) { - pub, _ := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - c := newTestchain(t, ` - G1 -> I1 -> I2 - I1.template = addKey - `, - optTemplate("addKey", AUM{MessageKind: AUMAddKey, Key: &key})) - - // G1 is before the key, so there shouldn't be a key there. - state, err := computeStateAt(c.Chonk(), 500, c.AUMHashes["G1"]) - if err != nil { - t.Fatalf("computeStateAt(G1) failed: %v", err) - } - if _, err := state.GetKey(key.MustID()); err != ErrNoSuchKey { - t.Errorf("expected key to be missing: err = %v", err) - } - if *state.LastAUMHash != c.AUMHashes["G1"] { - t.Errorf("LastAUMHash = %x, want %x", *state.LastAUMHash, c.AUMHashes["G1"]) - } - - // I1 & I2 are after the key, so the computed state should contain - // the key. - for _, wantHash := range []AUMHash{c.AUMHashes["I1"], c.AUMHashes["I2"]} { - state, err = computeStateAt(c.Chonk(), 500, wantHash) - if err != nil { - t.Fatalf("computeStateAt(%X) failed: %v", wantHash, err) - } - if *state.LastAUMHash != wantHash { - t.Errorf("LastAUMHash = %x, want %x", *state.LastAUMHash, wantHash) - } - if _, err := state.GetKey(key.MustID()); err != nil { - t.Errorf("expected key to be present at state: err = %v", err) - } - } -} - -// fakeAUM generates an AUM structure based on the template. -// If parent is provided, PrevAUMHash is set to that value. -// -// If template is an AUM, the returned AUM is based on that. -// If template is an int, a NOOP AUM is returned, and the -// provided int can be used to tweak the resulting hash (needed -// for tests you want one AUM to be 'lower' than another, so that -// that chain is taken based on fork resolution rules). -func fakeAUM(t *testing.T, template any, parent *AUMHash) (AUM, AUMHash) { - if seed, ok := template.(int); ok { - a := AUM{MessageKind: AUMNoOp, KeyID: []byte{byte(seed)}} - if parent != nil { - a.PrevAUMHash = (*parent)[:] - } - h := a.Hash() - return a, h - } - - if a, ok := template.(AUM); ok { - if parent != nil { - a.PrevAUMHash = (*parent)[:] - } - h := a.Hash() - return a, h - } - - panic("template must be an int or an AUM") -} - -func TestOpenAuthority(t *testing.T) { - pub, _ := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - // /- L1 - // G1 - I1 - I2 - I3 -L2 - // \-L3 - // G2 - L4 - // - // We set the previous-known ancestor to G1, so the - // ancestor to start from should be G1. - g1, g1H := fakeAUM(t, AUM{MessageKind: AUMAddKey, Key: &key}, nil) - i1, i1H := fakeAUM(t, 2, &g1H) // AUM{MessageKind: AUMAddKey, Key: &key2} - l1, l1H := fakeAUM(t, 13, &i1H) - - i2, i2H := fakeAUM(t, 2, &i1H) - i3, i3H := fakeAUM(t, 5, &i2H) - l2, l2H := fakeAUM(t, AUM{MessageKind: AUMNoOp, KeyID: []byte{7}, Signatures: []tkatype.Signature{{KeyID: key.MustID()}}}, &i3H) - l3, l3H := fakeAUM(t, 4, &i3H) - - g2, g2H := fakeAUM(t, 8, nil) - l4, _ := fakeAUM(t, 9, &g2H) - - // We make sure that I2 has a lower hash than L1, so - // it should take that path rather than L1. - if bytes.Compare(l1H[:], i2H[:]) < 0 { - t.Fatal("failed assert: h(i2) > h(l1)\nTweak parameters to fakeAUM till this passes") - } - // We make sure L2 has a signature with key, so it should - // take that path over L3. We assert that the L3 hash - // is less than L2 so the test will fail if the signature - // preference logic is broken. - if bytes.Compare(l2H[:], l3H[:]) < 0 { - t.Fatal("failed assert: h(l3) > h(l2)\nTweak parameters to fakeAUM till this passes") - } - - // Construct the state of durable storage. - chonk := &Mem{} - err := chonk.CommitVerifiedAUMs([]AUM{g1, i1, l1, i2, i3, l2, l3, g2, l4}) - if err != nil { - t.Fatal(err) - } - chonk.SetLastActiveAncestor(i1H) - - a, err := Open(chonk) - if err != nil { - t.Fatalf("New() failed: %v", err) - } - // Should include the key added in G1 - if _, err := a.state.GetKey(key.MustID()); err != nil { - t.Errorf("missing G1 key: %v", err) - } - // The head of the chain should be L2. - if a.Head() != l2H { - t.Errorf("head was %x, want %x", a.state.LastAUMHash, l2H) - } -} - -func TestOpenAuthority_EmptyErrors(t *testing.T) { - _, err := Open(&Mem{}) - if err == nil { - t.Error("Expected an error initializing an empty authority, got nil") - } -} - -func TestAuthorityHead(t *testing.T) { - c := newTestchain(t, ` - G1 -> L1 - | -> L2 - - L1.hashSeed = 2 - `) - - a, _ := Open(c.Chonk()) - if got, want := a.head.Hash(), a.Head(); got != want { - t.Errorf("Hash() returned %x, want %x", got, want) - } -} - -func TestAuthorityValidDisablement(t *testing.T) { - pub, _ := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - c := newTestchain(t, ` - G1 -> L1 - - G1.template = genesis - `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), - ) - - a, _ := Open(c.Chonk()) - if valid := a.ValidDisablement([]byte{1, 2, 3}); !valid { - t.Error("ValidDisablement() returned false, want true") - } -} - -func TestCreateBootstrapAuthority(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - a1, genesisAUM, err := Create(&Mem{}, State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, signer25519(priv)) - if err != nil { - t.Fatalf("Create() failed: %v", err) - } - - a2, err := Bootstrap(&Mem{}, genesisAUM) - if err != nil { - t.Fatalf("Bootstrap() failed: %v", err) - } - - if a1.Head() != a2.Head() { - t.Fatal("created and bootstrapped authority differ") - } - - // Both authorities should trust the key laid down in the genesis state. - if !a1.KeyTrusted(key.MustID()) { - t.Error("a1 did not trust genesis key") - } - if !a2.KeyTrusted(key.MustID()) { - t.Error("a2 did not trust genesis key") - } -} - -func TestAuthorityInformNonLinear(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - c := newTestchain(t, ` - G1 -> L1 - | -> L2 -> L3 - | -> L4 -> L5 - - G1.template = genesis - L1.hashSeed = 3 - L2.hashSeed = 2 - L4.hashSeed = 2 - `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), - optKey("key", key, priv), - optSignAllUsing("key")) - - storage := &Mem{} - a, err := Bootstrap(storage, c.AUMs["G1"]) - if err != nil { - t.Fatalf("Bootstrap() failed: %v", err) - } - - // L2 does not chain from L1, disabling the isHeadChain optimization - // and forcing Inform() to take the slow path. - informAUMs := []AUM{c.AUMs["L1"], c.AUMs["L2"], c.AUMs["L3"], c.AUMs["L4"], c.AUMs["L5"]} - - if err := a.Inform(storage, informAUMs); err != nil { - t.Fatalf("Inform() failed: %v", err) - } - for i, update := range informAUMs { - stored, err := storage.AUM(update.Hash()) - if err != nil { - t.Errorf("reading stored update %d: %v", i, err) - continue - } - if diff := cmp.Diff(update, stored); diff != "" { - t.Errorf("update %d differs (-want, +got):\n%s", i, diff) - } - } - - if a.Head() != c.AUMHashes["L3"] { - t.Fatal("authority did not converge to correct AUM") - } -} - -func TestAuthorityInformLinear(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - c := newTestchain(t, ` - G1 -> L1 -> L2 -> L3 - - G1.template = genesis - `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), - optKey("key", key, priv), - optSignAllUsing("key")) - - storage := &Mem{} - a, err := Bootstrap(storage, c.AUMs["G1"]) - if err != nil { - t.Fatalf("Bootstrap() failed: %v", err) - } - - informAUMs := []AUM{c.AUMs["L1"], c.AUMs["L2"], c.AUMs["L3"]} - - if err := a.Inform(storage, informAUMs); err != nil { - t.Fatalf("Inform() failed: %v", err) - } - for i, update := range informAUMs { - stored, err := storage.AUM(update.Hash()) - if err != nil { - t.Errorf("reading stored update %d: %v", i, err) - continue - } - if diff := cmp.Diff(update, stored); diff != "" { - t.Errorf("update %d differs (-want, +got):\n%s", i, diff) - } - } - - if a.Head() != c.AUMHashes["L3"] { - t.Fatal("authority did not converge to correct AUM") - } -} - -func TestInteropWithNLKey(t *testing.T) { - priv1 := key.NewNLPrivate() - pub1 := priv1.Public() - pub2 := key.NewNLPrivate().Public() - pub3 := key.NewNLPrivate().Public() - - a, _, err := Create(&Mem{}, State{ - Keys: []Key{ - { - Kind: Key25519, - Votes: 1, - Public: pub1.KeyID(), - }, - { - Kind: Key25519, - Votes: 1, - Public: pub2.KeyID(), - }, - }, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, priv1) - if err != nil { - t.Errorf("tka.Create: %v", err) - return - } - - if !a.KeyTrusted(pub1.KeyID()) { - t.Error("pub1 want trusted, got untrusted") - } - if !a.KeyTrusted(pub2.KeyID()) { - t.Error("pub2 want trusted, got untrusted") - } - if a.KeyTrusted(pub3.KeyID()) { - t.Error("pub3 want untrusted, got trusted") - } -} - -func TestAuthorityCompact(t *testing.T) { - pub, priv := testingKey25519(t, 1) - key := Key{Kind: Key25519, Public: pub, Votes: 2} - - c := newTestchain(t, ` - G -> A -> B -> C -> D -> E - - G.template = genesis - C.template = checkpoint2 - `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), - optTemplate("checkpoint2", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), - optKey("key", key, priv), - optSignAllUsing("key")) - - storage := &FS{base: t.TempDir()} - a, err := Bootstrap(storage, c.AUMs["G"]) - if err != nil { - t.Fatalf("Bootstrap() failed: %v", err) - } - a.Inform(storage, []AUM{c.AUMs["A"], c.AUMs["B"], c.AUMs["C"], c.AUMs["D"], c.AUMs["E"]}) - - // Should compact down to C -> D -> E - if err := a.Compact(storage, CompactionOptions{MinChain: 2, MinAge: 1}); err != nil { - t.Fatal(err) - } - if a.oldestAncestor.Hash() != c.AUMHashes["C"] { - t.Errorf("ancestor = %v, want %v", a.oldestAncestor.Hash(), c.AUMHashes["C"]) - } - - // Make sure the stored authority is still openable and resolves to the same state. - stored, err := Open(storage) - if err != nil { - t.Fatalf("Failed to open stored authority: %v", err) - } - if stored.Head() != a.Head() { - t.Errorf("Stored authority head differs: head = %v, want %v", stored.Head(), a.Head()) - } - t.Logf("original ancestor = %v", c.AUMHashes["G"]) - if anc, _ := storage.LastActiveAncestor(); *anc != c.AUMHashes["C"] { - t.Errorf("ancestor = %v, want %v", anc, c.AUMHashes["C"]) - } -} - -func TestFindParentForRewrite(t *testing.T) { - pub, _ := testingKey25519(t, 1) - k1 := Key{Kind: Key25519, Public: pub, Votes: 1} - - pub2, _ := testingKey25519(t, 2) - k2 := Key{Kind: Key25519, Public: pub2, Votes: 1} - k2ID, _ := k2.ID() - pub3, _ := testingKey25519(t, 3) - k3 := Key{Kind: Key25519, Public: pub3, Votes: 1} - - c := newTestchain(t, ` - A -> B -> C -> D -> E - A.template = genesis - B.template = add2 - C.template = add3 - D.template = remove2 - `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{k1}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), - optTemplate("add2", AUM{MessageKind: AUMAddKey, Key: &k2}), - optTemplate("add3", AUM{MessageKind: AUMAddKey, Key: &k3}), - optTemplate("remove2", AUM{MessageKind: AUMRemoveKey, KeyID: k2ID})) - - a, err := Open(c.Chonk()) - if err != nil { - t.Fatal(err) - } - - // k1 was trusted at genesis, so there's no better rewrite parent - // than the genesis. - k1ID, _ := k1.ID() - k1P, err := a.findParentForRewrite(c.Chonk(), []tkatype.KeyID{k1ID}, k1ID) - if err != nil { - t.Fatalf("FindParentForRewrite(k1) failed: %v", err) - } - if k1P != a.oldestAncestor.Hash() { - t.Errorf("FindParentForRewrite(k1) = %v, want %v", k1P, a.oldestAncestor.Hash()) - } - - // k3 was trusted at C, so B would be an ideal rewrite point. - k3ID, _ := k3.ID() - k3P, err := a.findParentForRewrite(c.Chonk(), []tkatype.KeyID{k3ID}, k1ID) - if err != nil { - t.Fatalf("FindParentForRewrite(k3) failed: %v", err) - } - if k3P != c.AUMHashes["B"] { - t.Errorf("FindParentForRewrite(k3) = %v, want %v", k3P, c.AUMHashes["B"]) - } - - // k2 was added but then removed, so HEAD is an appropriate rewrite point. - k2P, err := a.findParentForRewrite(c.Chonk(), []tkatype.KeyID{k2ID}, k1ID) - if err != nil { - t.Fatalf("FindParentForRewrite(k2) failed: %v", err) - } - if k3P != c.AUMHashes["B"] { - t.Errorf("FindParentForRewrite(k2) = %v, want %v", k2P, a.Head()) - } - - // There's no appropriate point where both k2 and k3 are simultaneously not trusted, - // so the best rewrite point is the genesis AUM. - doubleP, err := a.findParentForRewrite(c.Chonk(), []tkatype.KeyID{k2ID, k3ID}, k1ID) - if err != nil { - t.Fatalf("FindParentForRewrite({k2, k3}) failed: %v", err) - } - if doubleP != a.oldestAncestor.Hash() { - t.Errorf("FindParentForRewrite({k2, k3}) = %v, want %v", doubleP, a.oldestAncestor.Hash()) - } -} - -func TestMakeRetroactiveRevocation(t *testing.T) { - pub, _ := testingKey25519(t, 1) - k1 := Key{Kind: Key25519, Public: pub, Votes: 1} - - pub2, _ := testingKey25519(t, 2) - k2 := Key{Kind: Key25519, Public: pub2, Votes: 1} - pub3, _ := testingKey25519(t, 3) - k3 := Key{Kind: Key25519, Public: pub3, Votes: 1} - - c := newTestchain(t, ` - A -> B -> C -> D - A.template = genesis - C.template = add2 - D.template = add3 - `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{k1}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), - optTemplate("add2", AUM{MessageKind: AUMAddKey, Key: &k2}), - optTemplate("add3", AUM{MessageKind: AUMAddKey, Key: &k3})) - - a, err := Open(c.Chonk()) - if err != nil { - t.Fatal(err) - } - - // k2 was added by C, so a forking revocation should: - // - have B as a parent - // - trust the remaining keys at the time, k1 & k3. - k1ID, _ := k1.ID() - k2ID, _ := k2.ID() - k3ID, _ := k3.ID() - forkingAUM, err := a.MakeRetroactiveRevocation(c.Chonk(), []tkatype.KeyID{k2ID}, k1ID, AUMHash{}) - if err != nil { - t.Fatalf("MakeRetroactiveRevocation(k2) failed: %v", err) - } - if bHash := c.AUMHashes["B"]; !bytes.Equal(forkingAUM.PrevAUMHash, bHash[:]) { - t.Errorf("forking AUM has parent %v, want %v", forkingAUM.PrevAUMHash, bHash[:]) - } - if _, err := forkingAUM.State.GetKey(k1ID); err != nil { - t.Error("Forked state did not trust k1") - } - if _, err := forkingAUM.State.GetKey(k3ID); err != nil { - t.Error("Forked state did not trust k3") - } - if _, err := forkingAUM.State.GetKey(k2ID); err == nil { - t.Error("Forked state trusted removed-key k2") - } - - // Test that removing all trusted keys results in an error. - _, err = a.MakeRetroactiveRevocation(c.Chonk(), []tkatype.KeyID{k1ID, k2ID, k3ID}, k1ID, AUMHash{}) - if wantErr := "cannot revoke all trusted keys"; err == nil || err.Error() != wantErr { - t.Fatalf("MakeRetroactiveRevocation({k1, k2, k3}) returned %v, expected %q", err, wantErr) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "bytes" + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/types/key" + "tailscale.com/types/tkatype" +) + +func TestComputeChainCandidates(t *testing.T) { + c := newTestchain(t, ` + G1 -> I1 -> I2 -> I3 -> L2 + | -> L1 | -> L3 + + G2 -> L4 + + // We tweak these AUMs so they are different hashes. + G2.hashSeed = 2 + L1.hashSeed = 2 + L3.hashSeed = 2 + L4.hashSeed = 3 + `) + // Should result in 4 chains: + // G1->L1, G1->L2, G1->L3, G2->L4 + + i1H := c.AUMHashes["I1"] + got, err := computeChainCandidates(c.Chonk(), &i1H, 50) + if err != nil { + t.Fatalf("computeChainCandidates() failed: %v", err) + } + + want := []chain{ + {Oldest: c.AUMs["G2"], Head: c.AUMs["L4"]}, + {Oldest: c.AUMs["G1"], Head: c.AUMs["L3"], chainsThroughActive: true}, + {Oldest: c.AUMs["G1"], Head: c.AUMs["L1"], chainsThroughActive: true}, + {Oldest: c.AUMs["G1"], Head: c.AUMs["L2"], chainsThroughActive: true}, + } + if diff := cmp.Diff(want, got, cmp.AllowUnexported(chain{})); diff != "" { + t.Errorf("chains differ (-want, +got):\n%s", diff) + } +} + +func TestForkResolutionHash(t *testing.T) { + c := newTestchain(t, ` + G1 -> L1 + | -> L2 + + // tweak hashes so L1 & L2 are not identical + L1.hashSeed = 2 + L2.hashSeed = 3 + `) + + got, err := computeActiveChain(c.Chonk(), nil, 50) + if err != nil { + t.Fatalf("computeActiveChain() failed: %v", err) + } + + // The fork with the lowest AUM hash should have been chosen. + l1H := c.AUMHashes["L1"] + l2H := c.AUMHashes["L2"] + want := l1H + if bytes.Compare(l2H[:], l1H[:]) < 0 { + want = l2H + } + + if got := got.Head.Hash(); got != want { + t.Errorf("head was %x, want %x", got, want) + } +} + +func TestForkResolutionSigWeight(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + c := newTestchain(t, ` + G1 -> L1 + | -> L2 + + G1.template = addKey + L1.hashSeed = 11 + L2.signedWith = key + `, + optTemplate("addKey", AUM{MessageKind: AUMAddKey, Key: &key}), + optKey("key", key, priv)) + + l1H := c.AUMHashes["L1"] + l2H := c.AUMHashes["L2"] + if bytes.Compare(l2H[:], l1H[:]) < 0 { + t.Fatal("failed assert: h(l1) > h(l2)\nTweak hashSeed till this passes") + } + + got, err := computeActiveChain(c.Chonk(), nil, 50) + if err != nil { + t.Fatalf("computeActiveChain() failed: %v", err) + } + + // Based on the hash, l1H should be chosen. + // But based on the signature weight (which has higher + // precedence), it should be l2H + want := l2H + if got := got.Head.Hash(); got != want { + t.Errorf("head was %x, want %x", got, want) + } +} + +func TestForkResolutionMessageType(t *testing.T) { + pub, _ := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + c := newTestchain(t, ` + G1 -> L1 + | -> L2 + | -> L3 + + G1.template = addKey + L1.hashSeed = 11 + L2.template = removeKey + L3.hashSeed = 18 + `, + optTemplate("addKey", AUM{MessageKind: AUMAddKey, Key: &key}), + optTemplate("removeKey", AUM{MessageKind: AUMRemoveKey, KeyID: key.MustID()})) + + l1H := c.AUMHashes["L1"] + l2H := c.AUMHashes["L2"] + l3H := c.AUMHashes["L3"] + if bytes.Compare(l2H[:], l1H[:]) < 0 { + t.Fatal("failed assert: h(l1) > h(l2)\nTweak hashSeed till this passes") + } + if bytes.Compare(l2H[:], l3H[:]) < 0 { + t.Fatal("failed assert: h(l3) > h(l2)\nTweak hashSeed till this passes") + } + + got, err := computeActiveChain(c.Chonk(), nil, 50) + if err != nil { + t.Fatalf("computeActiveChain() failed: %v", err) + } + + // Based on the hash, L1 or L3 should be chosen. + // But based on the preference for AUMRemoveKey messages, + // it should be L2. + want := l2H + if got := got.Head.Hash(); got != want { + t.Errorf("head was %x, want %x", got, want) + } +} + +func TestComputeStateAt(t *testing.T) { + pub, _ := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + c := newTestchain(t, ` + G1 -> I1 -> I2 + I1.template = addKey + `, + optTemplate("addKey", AUM{MessageKind: AUMAddKey, Key: &key})) + + // G1 is before the key, so there shouldn't be a key there. + state, err := computeStateAt(c.Chonk(), 500, c.AUMHashes["G1"]) + if err != nil { + t.Fatalf("computeStateAt(G1) failed: %v", err) + } + if _, err := state.GetKey(key.MustID()); err != ErrNoSuchKey { + t.Errorf("expected key to be missing: err = %v", err) + } + if *state.LastAUMHash != c.AUMHashes["G1"] { + t.Errorf("LastAUMHash = %x, want %x", *state.LastAUMHash, c.AUMHashes["G1"]) + } + + // I1 & I2 are after the key, so the computed state should contain + // the key. + for _, wantHash := range []AUMHash{c.AUMHashes["I1"], c.AUMHashes["I2"]} { + state, err = computeStateAt(c.Chonk(), 500, wantHash) + if err != nil { + t.Fatalf("computeStateAt(%X) failed: %v", wantHash, err) + } + if *state.LastAUMHash != wantHash { + t.Errorf("LastAUMHash = %x, want %x", *state.LastAUMHash, wantHash) + } + if _, err := state.GetKey(key.MustID()); err != nil { + t.Errorf("expected key to be present at state: err = %v", err) + } + } +} + +// fakeAUM generates an AUM structure based on the template. +// If parent is provided, PrevAUMHash is set to that value. +// +// If template is an AUM, the returned AUM is based on that. +// If template is an int, a NOOP AUM is returned, and the +// provided int can be used to tweak the resulting hash (needed +// for tests you want one AUM to be 'lower' than another, so that +// that chain is taken based on fork resolution rules). +func fakeAUM(t *testing.T, template any, parent *AUMHash) (AUM, AUMHash) { + if seed, ok := template.(int); ok { + a := AUM{MessageKind: AUMNoOp, KeyID: []byte{byte(seed)}} + if parent != nil { + a.PrevAUMHash = (*parent)[:] + } + h := a.Hash() + return a, h + } + + if a, ok := template.(AUM); ok { + if parent != nil { + a.PrevAUMHash = (*parent)[:] + } + h := a.Hash() + return a, h + } + + panic("template must be an int or an AUM") +} + +func TestOpenAuthority(t *testing.T) { + pub, _ := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + // /- L1 + // G1 - I1 - I2 - I3 -L2 + // \-L3 + // G2 - L4 + // + // We set the previous-known ancestor to G1, so the + // ancestor to start from should be G1. + g1, g1H := fakeAUM(t, AUM{MessageKind: AUMAddKey, Key: &key}, nil) + i1, i1H := fakeAUM(t, 2, &g1H) // AUM{MessageKind: AUMAddKey, Key: &key2} + l1, l1H := fakeAUM(t, 13, &i1H) + + i2, i2H := fakeAUM(t, 2, &i1H) + i3, i3H := fakeAUM(t, 5, &i2H) + l2, l2H := fakeAUM(t, AUM{MessageKind: AUMNoOp, KeyID: []byte{7}, Signatures: []tkatype.Signature{{KeyID: key.MustID()}}}, &i3H) + l3, l3H := fakeAUM(t, 4, &i3H) + + g2, g2H := fakeAUM(t, 8, nil) + l4, _ := fakeAUM(t, 9, &g2H) + + // We make sure that I2 has a lower hash than L1, so + // it should take that path rather than L1. + if bytes.Compare(l1H[:], i2H[:]) < 0 { + t.Fatal("failed assert: h(i2) > h(l1)\nTweak parameters to fakeAUM till this passes") + } + // We make sure L2 has a signature with key, so it should + // take that path over L3. We assert that the L3 hash + // is less than L2 so the test will fail if the signature + // preference logic is broken. + if bytes.Compare(l2H[:], l3H[:]) < 0 { + t.Fatal("failed assert: h(l3) > h(l2)\nTweak parameters to fakeAUM till this passes") + } + + // Construct the state of durable storage. + chonk := &Mem{} + err := chonk.CommitVerifiedAUMs([]AUM{g1, i1, l1, i2, i3, l2, l3, g2, l4}) + if err != nil { + t.Fatal(err) + } + chonk.SetLastActiveAncestor(i1H) + + a, err := Open(chonk) + if err != nil { + t.Fatalf("New() failed: %v", err) + } + // Should include the key added in G1 + if _, err := a.state.GetKey(key.MustID()); err != nil { + t.Errorf("missing G1 key: %v", err) + } + // The head of the chain should be L2. + if a.Head() != l2H { + t.Errorf("head was %x, want %x", a.state.LastAUMHash, l2H) + } +} + +func TestOpenAuthority_EmptyErrors(t *testing.T) { + _, err := Open(&Mem{}) + if err == nil { + t.Error("Expected an error initializing an empty authority, got nil") + } +} + +func TestAuthorityHead(t *testing.T) { + c := newTestchain(t, ` + G1 -> L1 + | -> L2 + + L1.hashSeed = 2 + `) + + a, _ := Open(c.Chonk()) + if got, want := a.head.Hash(), a.Head(); got != want { + t.Errorf("Hash() returned %x, want %x", got, want) + } +} + +func TestAuthorityValidDisablement(t *testing.T) { + pub, _ := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + c := newTestchain(t, ` + G1 -> L1 + + G1.template = genesis + `, + optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }}), + ) + + a, _ := Open(c.Chonk()) + if valid := a.ValidDisablement([]byte{1, 2, 3}); !valid { + t.Error("ValidDisablement() returned false, want true") + } +} + +func TestCreateBootstrapAuthority(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + a1, genesisAUM, err := Create(&Mem{}, State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }, signer25519(priv)) + if err != nil { + t.Fatalf("Create() failed: %v", err) + } + + a2, err := Bootstrap(&Mem{}, genesisAUM) + if err != nil { + t.Fatalf("Bootstrap() failed: %v", err) + } + + if a1.Head() != a2.Head() { + t.Fatal("created and bootstrapped authority differ") + } + + // Both authorities should trust the key laid down in the genesis state. + if !a1.KeyTrusted(key.MustID()) { + t.Error("a1 did not trust genesis key") + } + if !a2.KeyTrusted(key.MustID()) { + t.Error("a2 did not trust genesis key") + } +} + +func TestAuthorityInformNonLinear(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + c := newTestchain(t, ` + G1 -> L1 + | -> L2 -> L3 + | -> L4 -> L5 + + G1.template = genesis + L1.hashSeed = 3 + L2.hashSeed = 2 + L4.hashSeed = 2 + `, + optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }}), + optKey("key", key, priv), + optSignAllUsing("key")) + + storage := &Mem{} + a, err := Bootstrap(storage, c.AUMs["G1"]) + if err != nil { + t.Fatalf("Bootstrap() failed: %v", err) + } + + // L2 does not chain from L1, disabling the isHeadChain optimization + // and forcing Inform() to take the slow path. + informAUMs := []AUM{c.AUMs["L1"], c.AUMs["L2"], c.AUMs["L3"], c.AUMs["L4"], c.AUMs["L5"]} + + if err := a.Inform(storage, informAUMs); err != nil { + t.Fatalf("Inform() failed: %v", err) + } + for i, update := range informAUMs { + stored, err := storage.AUM(update.Hash()) + if err != nil { + t.Errorf("reading stored update %d: %v", i, err) + continue + } + if diff := cmp.Diff(update, stored); diff != "" { + t.Errorf("update %d differs (-want, +got):\n%s", i, diff) + } + } + + if a.Head() != c.AUMHashes["L3"] { + t.Fatal("authority did not converge to correct AUM") + } +} + +func TestAuthorityInformLinear(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + c := newTestchain(t, ` + G1 -> L1 -> L2 -> L3 + + G1.template = genesis + `, + optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }}), + optKey("key", key, priv), + optSignAllUsing("key")) + + storage := &Mem{} + a, err := Bootstrap(storage, c.AUMs["G1"]) + if err != nil { + t.Fatalf("Bootstrap() failed: %v", err) + } + + informAUMs := []AUM{c.AUMs["L1"], c.AUMs["L2"], c.AUMs["L3"]} + + if err := a.Inform(storage, informAUMs); err != nil { + t.Fatalf("Inform() failed: %v", err) + } + for i, update := range informAUMs { + stored, err := storage.AUM(update.Hash()) + if err != nil { + t.Errorf("reading stored update %d: %v", i, err) + continue + } + if diff := cmp.Diff(update, stored); diff != "" { + t.Errorf("update %d differs (-want, +got):\n%s", i, diff) + } + } + + if a.Head() != c.AUMHashes["L3"] { + t.Fatal("authority did not converge to correct AUM") + } +} + +func TestInteropWithNLKey(t *testing.T) { + priv1 := key.NewNLPrivate() + pub1 := priv1.Public() + pub2 := key.NewNLPrivate().Public() + pub3 := key.NewNLPrivate().Public() + + a, _, err := Create(&Mem{}, State{ + Keys: []Key{ + { + Kind: Key25519, + Votes: 1, + Public: pub1.KeyID(), + }, + { + Kind: Key25519, + Votes: 1, + Public: pub2.KeyID(), + }, + }, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }, priv1) + if err != nil { + t.Errorf("tka.Create: %v", err) + return + } + + if !a.KeyTrusted(pub1.KeyID()) { + t.Error("pub1 want trusted, got untrusted") + } + if !a.KeyTrusted(pub2.KeyID()) { + t.Error("pub2 want trusted, got untrusted") + } + if a.KeyTrusted(pub3.KeyID()) { + t.Error("pub3 want untrusted, got trusted") + } +} + +func TestAuthorityCompact(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + c := newTestchain(t, ` + G -> A -> B -> C -> D -> E + + G.template = genesis + C.template = checkpoint2 + `, + optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }}), + optTemplate("checkpoint2", AUM{MessageKind: AUMCheckpoint, State: &State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }}), + optKey("key", key, priv), + optSignAllUsing("key")) + + storage := &FS{base: t.TempDir()} + a, err := Bootstrap(storage, c.AUMs["G"]) + if err != nil { + t.Fatalf("Bootstrap() failed: %v", err) + } + a.Inform(storage, []AUM{c.AUMs["A"], c.AUMs["B"], c.AUMs["C"], c.AUMs["D"], c.AUMs["E"]}) + + // Should compact down to C -> D -> E + if err := a.Compact(storage, CompactionOptions{MinChain: 2, MinAge: 1}); err != nil { + t.Fatal(err) + } + if a.oldestAncestor.Hash() != c.AUMHashes["C"] { + t.Errorf("ancestor = %v, want %v", a.oldestAncestor.Hash(), c.AUMHashes["C"]) + } + + // Make sure the stored authority is still openable and resolves to the same state. + stored, err := Open(storage) + if err != nil { + t.Fatalf("Failed to open stored authority: %v", err) + } + if stored.Head() != a.Head() { + t.Errorf("Stored authority head differs: head = %v, want %v", stored.Head(), a.Head()) + } + t.Logf("original ancestor = %v", c.AUMHashes["G"]) + if anc, _ := storage.LastActiveAncestor(); *anc != c.AUMHashes["C"] { + t.Errorf("ancestor = %v, want %v", anc, c.AUMHashes["C"]) + } +} + +func TestFindParentForRewrite(t *testing.T) { + pub, _ := testingKey25519(t, 1) + k1 := Key{Kind: Key25519, Public: pub, Votes: 1} + + pub2, _ := testingKey25519(t, 2) + k2 := Key{Kind: Key25519, Public: pub2, Votes: 1} + k2ID, _ := k2.ID() + pub3, _ := testingKey25519(t, 3) + k3 := Key{Kind: Key25519, Public: pub3, Votes: 1} + + c := newTestchain(t, ` + A -> B -> C -> D -> E + A.template = genesis + B.template = add2 + C.template = add3 + D.template = remove2 + `, + optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ + Keys: []Key{k1}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }}), + optTemplate("add2", AUM{MessageKind: AUMAddKey, Key: &k2}), + optTemplate("add3", AUM{MessageKind: AUMAddKey, Key: &k3}), + optTemplate("remove2", AUM{MessageKind: AUMRemoveKey, KeyID: k2ID})) + + a, err := Open(c.Chonk()) + if err != nil { + t.Fatal(err) + } + + // k1 was trusted at genesis, so there's no better rewrite parent + // than the genesis. + k1ID, _ := k1.ID() + k1P, err := a.findParentForRewrite(c.Chonk(), []tkatype.KeyID{k1ID}, k1ID) + if err != nil { + t.Fatalf("FindParentForRewrite(k1) failed: %v", err) + } + if k1P != a.oldestAncestor.Hash() { + t.Errorf("FindParentForRewrite(k1) = %v, want %v", k1P, a.oldestAncestor.Hash()) + } + + // k3 was trusted at C, so B would be an ideal rewrite point. + k3ID, _ := k3.ID() + k3P, err := a.findParentForRewrite(c.Chonk(), []tkatype.KeyID{k3ID}, k1ID) + if err != nil { + t.Fatalf("FindParentForRewrite(k3) failed: %v", err) + } + if k3P != c.AUMHashes["B"] { + t.Errorf("FindParentForRewrite(k3) = %v, want %v", k3P, c.AUMHashes["B"]) + } + + // k2 was added but then removed, so HEAD is an appropriate rewrite point. + k2P, err := a.findParentForRewrite(c.Chonk(), []tkatype.KeyID{k2ID}, k1ID) + if err != nil { + t.Fatalf("FindParentForRewrite(k2) failed: %v", err) + } + if k3P != c.AUMHashes["B"] { + t.Errorf("FindParentForRewrite(k2) = %v, want %v", k2P, a.Head()) + } + + // There's no appropriate point where both k2 and k3 are simultaneously not trusted, + // so the best rewrite point is the genesis AUM. + doubleP, err := a.findParentForRewrite(c.Chonk(), []tkatype.KeyID{k2ID, k3ID}, k1ID) + if err != nil { + t.Fatalf("FindParentForRewrite({k2, k3}) failed: %v", err) + } + if doubleP != a.oldestAncestor.Hash() { + t.Errorf("FindParentForRewrite({k2, k3}) = %v, want %v", doubleP, a.oldestAncestor.Hash()) + } +} + +func TestMakeRetroactiveRevocation(t *testing.T) { + pub, _ := testingKey25519(t, 1) + k1 := Key{Kind: Key25519, Public: pub, Votes: 1} + + pub2, _ := testingKey25519(t, 2) + k2 := Key{Kind: Key25519, Public: pub2, Votes: 1} + pub3, _ := testingKey25519(t, 3) + k3 := Key{Kind: Key25519, Public: pub3, Votes: 1} + + c := newTestchain(t, ` + A -> B -> C -> D + A.template = genesis + C.template = add2 + D.template = add3 + `, + optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ + Keys: []Key{k1}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }}), + optTemplate("add2", AUM{MessageKind: AUMAddKey, Key: &k2}), + optTemplate("add3", AUM{MessageKind: AUMAddKey, Key: &k3})) + + a, err := Open(c.Chonk()) + if err != nil { + t.Fatal(err) + } + + // k2 was added by C, so a forking revocation should: + // - have B as a parent + // - trust the remaining keys at the time, k1 & k3. + k1ID, _ := k1.ID() + k2ID, _ := k2.ID() + k3ID, _ := k3.ID() + forkingAUM, err := a.MakeRetroactiveRevocation(c.Chonk(), []tkatype.KeyID{k2ID}, k1ID, AUMHash{}) + if err != nil { + t.Fatalf("MakeRetroactiveRevocation(k2) failed: %v", err) + } + if bHash := c.AUMHashes["B"]; !bytes.Equal(forkingAUM.PrevAUMHash, bHash[:]) { + t.Errorf("forking AUM has parent %v, want %v", forkingAUM.PrevAUMHash, bHash[:]) + } + if _, err := forkingAUM.State.GetKey(k1ID); err != nil { + t.Error("Forked state did not trust k1") + } + if _, err := forkingAUM.State.GetKey(k3ID); err != nil { + t.Error("Forked state did not trust k3") + } + if _, err := forkingAUM.State.GetKey(k2ID); err == nil { + t.Error("Forked state trusted removed-key k2") + } + + // Test that removing all trusted keys results in an error. + _, err = a.MakeRetroactiveRevocation(c.Chonk(), []tkatype.KeyID{k1ID, k2ID, k3ID}, k1ID, AUMHash{}) + if wantErr := "cannot revoke all trusted keys"; err == nil || err.Error() != wantErr { + t.Fatalf("MakeRetroactiveRevocation({k1, k2, k3}) returned %v, expected %q", err, wantErr) + } +} diff --git a/tool/binaryen.rev b/tool/binaryen.rev index e0d03ab88..58c9bdf9d 100644 --- a/tool/binaryen.rev +++ b/tool/binaryen.rev @@ -1 +1 @@ -111 +111 diff --git a/tool/go b/tool/go index 3c99f3e2f..1c53683d5 100755 --- a/tool/go +++ b/tool/go @@ -1,7 +1,7 @@ -#!/bin/sh -# -# This script acts like the "go" command, but uses Tailscale's -# currently-desired version from https://github.com/tailscale/go, -# downloading it first if necessary. - -exec "$(dirname "$0")/../tool/gocross/gocross-wrapper.sh" "$@" +#!/bin/sh +# +# This script acts like the "go" command, but uses Tailscale's +# currently-desired version from https://github.com/tailscale/go, +# downloading it first if necessary. + +exec "$(dirname "$0")/../tool/gocross/gocross-wrapper.sh" "$@" diff --git a/tool/gocross/env.go b/tool/gocross/env.go index 249476dc1..9d8a4f1b3 100644 --- a/tool/gocross/env.go +++ b/tool/gocross/env.go @@ -1,131 +1,131 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -import ( - "fmt" - "os" - "sort" - "strings" -) - -// Environment starts from an initial set of environment variables, and tracks -// mutations to the environment. It can then apply those mutations to the -// environment, or produce debugging output that illustrates the changes it -// would make. -type Environment struct { - init map[string]string - set map[string]string - unset map[string]bool - - setenv func(string, string) error - unsetenv func(string) error -} - -// NewEnvironment returns an Environment initialized from os.Environ. -func NewEnvironment() *Environment { - init := map[string]string{} - for _, env := range os.Environ() { - fs := strings.SplitN(env, "=", 2) - if len(fs) != 2 { - panic("bad environ provided") - } - init[fs[0]] = fs[1] - } - - return newEnvironmentForTest(init, os.Setenv, os.Unsetenv) -} - -func newEnvironmentForTest(init map[string]string, setenv func(string, string) error, unsetenv func(string) error) *Environment { - return &Environment{ - init: init, - set: map[string]string{}, - unset: map[string]bool{}, - setenv: setenv, - unsetenv: unsetenv, - } -} - -// Set sets the environment variable k to v. -func (e *Environment) Set(k, v string) { - e.set[k] = v - delete(e.unset, k) -} - -// Unset removes the environment variable k. -func (e *Environment) Unset(k string) { - delete(e.set, k) - e.unset[k] = true -} - -// IsSet reports whether the environment variable k is set. -func (e *Environment) IsSet(k string) bool { - if e.unset[k] { - return false - } - if _, ok := e.init[k]; ok { - return true - } - if _, ok := e.set[k]; ok { - return true - } - return false -} - -// Get returns the value of the environment variable k, or defaultVal if it is -// not set. -func (e *Environment) Get(k, defaultVal string) string { - if e.unset[k] { - return defaultVal - } - if v, ok := e.set[k]; ok { - return v - } - if v, ok := e.init[k]; ok { - return v - } - return defaultVal -} - -// Apply applies all pending mutations to the environment. -func (e *Environment) Apply() error { - for k, v := range e.set { - if err := e.setenv(k, v); err != nil { - return fmt.Errorf("setting %q: %v", k, err) - } - e.init[k] = v - delete(e.set, k) - } - for k := range e.unset { - if err := e.unsetenv(k); err != nil { - return fmt.Errorf("unsetting %q: %v", k, err) - } - delete(e.init, k) - delete(e.unset, k) - } - return nil -} - -// Diff returns a string describing the pending mutations to the environment. -func (e *Environment) Diff() string { - lines := make([]string, 0, len(e.set)+len(e.unset)) - for k, v := range e.set { - old, ok := e.init[k] - if ok { - lines = append(lines, fmt.Sprintf("%s=%s (was %s)", k, v, old)) - } else { - lines = append(lines, fmt.Sprintf("%s=%s (was )", k, v)) - } - } - for k := range e.unset { - old, ok := e.init[k] - if ok { - lines = append(lines, fmt.Sprintf("%s= (was %s)", k, old)) - } else { - lines = append(lines, fmt.Sprintf("%s= (was )", k)) - } - } - sort.Strings(lines) - return strings.Join(lines, "\n") -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "fmt" + "os" + "sort" + "strings" +) + +// Environment starts from an initial set of environment variables, and tracks +// mutations to the environment. It can then apply those mutations to the +// environment, or produce debugging output that illustrates the changes it +// would make. +type Environment struct { + init map[string]string + set map[string]string + unset map[string]bool + + setenv func(string, string) error + unsetenv func(string) error +} + +// NewEnvironment returns an Environment initialized from os.Environ. +func NewEnvironment() *Environment { + init := map[string]string{} + for _, env := range os.Environ() { + fs := strings.SplitN(env, "=", 2) + if len(fs) != 2 { + panic("bad environ provided") + } + init[fs[0]] = fs[1] + } + + return newEnvironmentForTest(init, os.Setenv, os.Unsetenv) +} + +func newEnvironmentForTest(init map[string]string, setenv func(string, string) error, unsetenv func(string) error) *Environment { + return &Environment{ + init: init, + set: map[string]string{}, + unset: map[string]bool{}, + setenv: setenv, + unsetenv: unsetenv, + } +} + +// Set sets the environment variable k to v. +func (e *Environment) Set(k, v string) { + e.set[k] = v + delete(e.unset, k) +} + +// Unset removes the environment variable k. +func (e *Environment) Unset(k string) { + delete(e.set, k) + e.unset[k] = true +} + +// IsSet reports whether the environment variable k is set. +func (e *Environment) IsSet(k string) bool { + if e.unset[k] { + return false + } + if _, ok := e.init[k]; ok { + return true + } + if _, ok := e.set[k]; ok { + return true + } + return false +} + +// Get returns the value of the environment variable k, or defaultVal if it is +// not set. +func (e *Environment) Get(k, defaultVal string) string { + if e.unset[k] { + return defaultVal + } + if v, ok := e.set[k]; ok { + return v + } + if v, ok := e.init[k]; ok { + return v + } + return defaultVal +} + +// Apply applies all pending mutations to the environment. +func (e *Environment) Apply() error { + for k, v := range e.set { + if err := e.setenv(k, v); err != nil { + return fmt.Errorf("setting %q: %v", k, err) + } + e.init[k] = v + delete(e.set, k) + } + for k := range e.unset { + if err := e.unsetenv(k); err != nil { + return fmt.Errorf("unsetting %q: %v", k, err) + } + delete(e.init, k) + delete(e.unset, k) + } + return nil +} + +// Diff returns a string describing the pending mutations to the environment. +func (e *Environment) Diff() string { + lines := make([]string, 0, len(e.set)+len(e.unset)) + for k, v := range e.set { + old, ok := e.init[k] + if ok { + lines = append(lines, fmt.Sprintf("%s=%s (was %s)", k, v, old)) + } else { + lines = append(lines, fmt.Sprintf("%s=%s (was )", k, v)) + } + } + for k := range e.unset { + old, ok := e.init[k] + if ok { + lines = append(lines, fmt.Sprintf("%s= (was %s)", k, old)) + } else { + lines = append(lines, fmt.Sprintf("%s= (was )", k)) + } + } + sort.Strings(lines) + return strings.Join(lines, "\n") +} diff --git a/tool/gocross/env_test.go b/tool/gocross/env_test.go index 9a797530d..001487bb8 100644 --- a/tool/gocross/env_test.go +++ b/tool/gocross/env_test.go @@ -1,99 +1,99 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -import ( - "testing" - - "github.com/google/go-cmp/cmp" -) - -func TestEnv(t *testing.T) { - - var ( - init = map[string]string{ - "FOO": "bar", - } - - wasSet = map[string]string{} - wasUnset = map[string]bool{} - - setenv = func(k, v string) error { - wasSet[k] = v - return nil - } - unsetenv = func(k string) error { - wasUnset[k] = true - return nil - } - ) - - env := newEnvironmentForTest(init, setenv, unsetenv) - - if got, want := env.Get("FOO", ""), "bar"; got != want { - t.Errorf(`env.Get("FOO") = %q, want %q`, got, want) - } - if got, want := env.IsSet("FOO"), true; got != want { - t.Errorf(`env.IsSet("FOO") = %v, want %v`, got, want) - } - - if got, want := env.Get("BAR", "defaultVal"), "defaultVal"; got != want { - t.Errorf(`env.Get("BAR") = %q, want %q`, got, want) - } - if got, want := env.IsSet("BAR"), false; got != want { - t.Errorf(`env.IsSet("BAR") = %v, want %v`, got, want) - } - - env.Set("BAR", "quux") - if got, want := env.Get("BAR", ""), "quux"; got != want { - t.Errorf(`env.Get("BAR") = %q, want %q`, got, want) - } - if got, want := env.IsSet("BAR"), true; got != want { - t.Errorf(`env.IsSet("BAR") = %v, want %v`, got, want) - } - diff := "BAR=quux (was )" - if got := env.Diff(); got != diff { - t.Errorf("env.Diff() = %q, want %q", got, diff) - } - - env.Set("FOO", "foo2") - if got, want := env.Get("FOO", ""), "foo2"; got != want { - t.Errorf(`env.Get("FOO") = %q, want %q`, got, want) - } - if got, want := env.IsSet("FOO"), true; got != want { - t.Errorf(`env.IsSet("FOO") = %v, want %v`, got, want) - } - diff = `BAR=quux (was ) -FOO=foo2 (was bar)` - if got := env.Diff(); got != diff { - t.Errorf("env.Diff() = %q, want %q", got, diff) - } - - env.Unset("FOO") - if got, want := env.Get("FOO", "default"), "default"; got != want { - t.Errorf(`env.Get("FOO") = %q, want %q`, got, want) - } - if got, want := env.IsSet("FOO"), false; got != want { - t.Errorf(`env.IsSet("FOO") = %v, want %v`, got, want) - } - diff = `BAR=quux (was ) -FOO= (was bar)` - if got := env.Diff(); got != diff { - t.Errorf("env.Diff() = %q, want %q", got, diff) - } - - if err := env.Apply(); err != nil { - t.Fatalf("env.Apply() failed: %v", err) - } - - wantSet := map[string]string{"BAR": "quux"} - wantUnset := map[string]bool{"FOO": true} - - if diff := cmp.Diff(wasSet, wantSet); diff != "" { - t.Errorf("env.Apply didn't set as expected (-got+want):\n%s", diff) - } - if diff := cmp.Diff(wasUnset, wantUnset); diff != "" { - t.Errorf("env.Apply didn't unset as expected (-got+want):\n%s", diff) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestEnv(t *testing.T) { + + var ( + init = map[string]string{ + "FOO": "bar", + } + + wasSet = map[string]string{} + wasUnset = map[string]bool{} + + setenv = func(k, v string) error { + wasSet[k] = v + return nil + } + unsetenv = func(k string) error { + wasUnset[k] = true + return nil + } + ) + + env := newEnvironmentForTest(init, setenv, unsetenv) + + if got, want := env.Get("FOO", ""), "bar"; got != want { + t.Errorf(`env.Get("FOO") = %q, want %q`, got, want) + } + if got, want := env.IsSet("FOO"), true; got != want { + t.Errorf(`env.IsSet("FOO") = %v, want %v`, got, want) + } + + if got, want := env.Get("BAR", "defaultVal"), "defaultVal"; got != want { + t.Errorf(`env.Get("BAR") = %q, want %q`, got, want) + } + if got, want := env.IsSet("BAR"), false; got != want { + t.Errorf(`env.IsSet("BAR") = %v, want %v`, got, want) + } + + env.Set("BAR", "quux") + if got, want := env.Get("BAR", ""), "quux"; got != want { + t.Errorf(`env.Get("BAR") = %q, want %q`, got, want) + } + if got, want := env.IsSet("BAR"), true; got != want { + t.Errorf(`env.IsSet("BAR") = %v, want %v`, got, want) + } + diff := "BAR=quux (was )" + if got := env.Diff(); got != diff { + t.Errorf("env.Diff() = %q, want %q", got, diff) + } + + env.Set("FOO", "foo2") + if got, want := env.Get("FOO", ""), "foo2"; got != want { + t.Errorf(`env.Get("FOO") = %q, want %q`, got, want) + } + if got, want := env.IsSet("FOO"), true; got != want { + t.Errorf(`env.IsSet("FOO") = %v, want %v`, got, want) + } + diff = `BAR=quux (was ) +FOO=foo2 (was bar)` + if got := env.Diff(); got != diff { + t.Errorf("env.Diff() = %q, want %q", got, diff) + } + + env.Unset("FOO") + if got, want := env.Get("FOO", "default"), "default"; got != want { + t.Errorf(`env.Get("FOO") = %q, want %q`, got, want) + } + if got, want := env.IsSet("FOO"), false; got != want { + t.Errorf(`env.IsSet("FOO") = %v, want %v`, got, want) + } + diff = `BAR=quux (was ) +FOO= (was bar)` + if got := env.Diff(); got != diff { + t.Errorf("env.Diff() = %q, want %q", got, diff) + } + + if err := env.Apply(); err != nil { + t.Fatalf("env.Apply() failed: %v", err) + } + + wantSet := map[string]string{"BAR": "quux"} + wantUnset := map[string]bool{"FOO": true} + + if diff := cmp.Diff(wasSet, wantSet); diff != "" { + t.Errorf("env.Apply didn't set as expected (-got+want):\n%s", diff) + } + if diff := cmp.Diff(wasUnset, wantUnset); diff != "" { + t.Errorf("env.Apply didn't unset as expected (-got+want):\n%s", diff) + } +} diff --git a/tool/gocross/exec_other.go b/tool/gocross/exec_other.go index ec9663df7..8d4df0db3 100644 --- a/tool/gocross/exec_other.go +++ b/tool/gocross/exec_other.go @@ -1,20 +1,20 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !unix - -package main - -import ( - "os" - "os/exec" -) - -func doExec(cmd string, args []string, env []string) error { - c := exec.Command(cmd, args...) - c.Env = env - c.Stdin = os.Stdin - c.Stdout = os.Stdout - c.Stderr = os.Stderr - return c.Run() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !unix + +package main + +import ( + "os" + "os/exec" +) + +func doExec(cmd string, args []string, env []string) error { + c := exec.Command(cmd, args...) + c.Env = env + c.Stdin = os.Stdin + c.Stdout = os.Stdout + c.Stderr = os.Stderr + return c.Run() +} diff --git a/tool/gocross/exec_unix.go b/tool/gocross/exec_unix.go index eeffd5f93..79cbf764a 100644 --- a/tool/gocross/exec_unix.go +++ b/tool/gocross/exec_unix.go @@ -1,12 +1,12 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build unix - -package main - -import "golang.org/x/sys/unix" - -func doExec(cmd string, args []string, env []string) error { - return unix.Exec(cmd, args, env) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build unix + +package main + +import "golang.org/x/sys/unix" + +func doExec(cmd string, args []string, env []string) error { + return unix.Exec(cmd, args, env) +} diff --git a/tool/helm b/tool/helm index 8cbc2f206..3f9a9dfd5 100755 --- a/tool/helm +++ b/tool/helm @@ -1,69 +1,69 @@ -#!/usr/bin/env bash - -# installs $(cat ./helm.rev) version of helm as $HOME/.cache/tailscale-helm - -set -euo pipefail - -if [[ "${CI:-}" == "true" ]]; then - set -x -fi - -( - if [[ "${CI:-}" == "true" ]]; then - set -x - fi - - repo_root="${BASH_SOURCE%/*}/../" - cd "$repo_root" - - cachedir="$HOME/.cache/tailscale-helm" - tarball="${cachedir}.tar.gz" - - read -r want_rev < "$(dirname "$0")/helm.rev" - - got_rev="" - if [[ -x "${cachedir}/helm" ]]; then - got_rev=$("${cachedir}/helm" version --short) - got_rev="${got_rev#v}" # trim the leading 'v' - got_rev="${got_rev%+*}" # trim the trailing '+" followed by a commit SHA' - - - fi - - if [[ "$want_rev" != "$got_rev" ]]; then - rm -rf "$cachedir" "$tarball" - if [[ -n "${IN_NIX_SHELL:-}" ]]; then - nix_helm="$(which -a helm | grep /nix/store | head -1)" - nix_helm="${nix_helm%/helm}" - nix_helm_rev="${nix_helm##*-}" - if [[ "$nix_helm_rev" != "$want_rev" ]]; then - echo "Wrong helm version in Nix, got $nix_helm_rev want $want_rev" >&2 - exit 1 - fi - ln -sf "$nix_helm" "$cachedir" - else - # works for linux and darwin - # https://github.com/helm/helm/releases - OS=$(uname -s | tr A-Z a-z) - ARCH=$(uname -m) - if [ "$ARCH" = "x86_64" ]; then - ARCH="amd64" - fi - if [ "$ARCH" = "aarch64" ]; then - ARCH="arm64" - fi - mkdir -p "$cachedir" - # When running on GitHub in CI, the below curl sometimes fails with - # INTERNAL_ERROR after finishing the download. The most common cause - # of INTERNAL_ERROR is glitches in intermediate hosts handling of - # HTTP/2 forwarding, so forcing HTTP 1.1 often fixes the issue. See - # https://github.com/tailscale/tailscale/issues/8988 - curl -f -L --http1.1 -o "$tarball" -sSL "https://get.helm.sh/helm-v${want_rev}-${OS}-${ARCH}.tar.gz" - (cd "$cachedir" && tar --strip-components=1 -xf "$tarball") - rm -f "$tarball" - fi - fi -) - -export PATH="$HOME/.cache/tailscale-helm:$PATH" -exec "$HOME/.cache/tailscale-helm/helm" "$@" +#!/usr/bin/env bash + +# installs $(cat ./helm.rev) version of helm as $HOME/.cache/tailscale-helm + +set -euo pipefail + +if [[ "${CI:-}" == "true" ]]; then + set -x +fi + +( + if [[ "${CI:-}" == "true" ]]; then + set -x + fi + + repo_root="${BASH_SOURCE%/*}/../" + cd "$repo_root" + + cachedir="$HOME/.cache/tailscale-helm" + tarball="${cachedir}.tar.gz" + + read -r want_rev < "$(dirname "$0")/helm.rev" + + got_rev="" + if [[ -x "${cachedir}/helm" ]]; then + got_rev=$("${cachedir}/helm" version --short) + got_rev="${got_rev#v}" # trim the leading 'v' + got_rev="${got_rev%+*}" # trim the trailing '+" followed by a commit SHA' + + + fi + + if [[ "$want_rev" != "$got_rev" ]]; then + rm -rf "$cachedir" "$tarball" + if [[ -n "${IN_NIX_SHELL:-}" ]]; then + nix_helm="$(which -a helm | grep /nix/store | head -1)" + nix_helm="${nix_helm%/helm}" + nix_helm_rev="${nix_helm##*-}" + if [[ "$nix_helm_rev" != "$want_rev" ]]; then + echo "Wrong helm version in Nix, got $nix_helm_rev want $want_rev" >&2 + exit 1 + fi + ln -sf "$nix_helm" "$cachedir" + else + # works for linux and darwin + # https://github.com/helm/helm/releases + OS=$(uname -s | tr A-Z a-z) + ARCH=$(uname -m) + if [ "$ARCH" = "x86_64" ]; then + ARCH="amd64" + fi + if [ "$ARCH" = "aarch64" ]; then + ARCH="arm64" + fi + mkdir -p "$cachedir" + # When running on GitHub in CI, the below curl sometimes fails with + # INTERNAL_ERROR after finishing the download. The most common cause + # of INTERNAL_ERROR is glitches in intermediate hosts handling of + # HTTP/2 forwarding, so forcing HTTP 1.1 often fixes the issue. See + # https://github.com/tailscale/tailscale/issues/8988 + curl -f -L --http1.1 -o "$tarball" -sSL "https://get.helm.sh/helm-v${want_rev}-${OS}-${ARCH}.tar.gz" + (cd "$cachedir" && tar --strip-components=1 -xf "$tarball") + rm -f "$tarball" + fi + fi +) + +export PATH="$HOME/.cache/tailscale-helm:$PATH" +exec "$HOME/.cache/tailscale-helm/helm" "$@" diff --git a/tool/helm.rev b/tool/helm.rev index 0d0e48dd0..c10780c62 100644 --- a/tool/helm.rev +++ b/tool/helm.rev @@ -1 +1 @@ -3.13.1 +3.13.1 diff --git a/tool/node b/tool/node index 7e96826f3..310140ae5 100755 --- a/tool/node +++ b/tool/node @@ -1,65 +1,65 @@ -#!/usr/bin/env bash -# Run a command with our local node install, rather than any globally installed -# instance. - -set -euo pipefail - -if [[ "${CI:-}" == "true" ]]; then - set -x -fi - -( - if [[ "${CI:-}" == "true" ]]; then - set -x - fi - - repo_root="${BASH_SOURCE%/*}/../" - cd "$repo_root" - - cachedir="$HOME/.cache/tailscale-node" - tarball="${cachedir}.tar.gz" - - read -r want_rev < "$(dirname "$0")/node.rev" - - got_rev="" - if [[ -x "${cachedir}/bin/node" ]]; then - got_rev=$("${cachedir}/bin/node" --version) - got_rev="${got_rev#v}" # trim the leading 'v' - fi - - if [[ "$want_rev" != "$got_rev" ]]; then - rm -rf "$cachedir" "$tarball" - if [[ -n "${IN_NIX_SHELL:-}" ]]; then - nix_node="$(which -a node | grep /nix/store | head -1)" - nix_node="${nix_node%/bin/node}" - nix_node_rev="${nix_node##*-}" - if [[ "$nix_node_rev" != "$want_rev" ]]; then - echo "Wrong node version in Nix, got $nix_node_rev want $want_rev" >&2 - exit 1 - fi - ln -sf "$nix_node" "$cachedir" - else - # works for "linux" and "darwin" - OS=$(uname -s | tr A-Z a-z) - ARCH=$(uname -m) - if [ "$ARCH" = "x86_64" ]; then - ARCH="x64" - fi - if [ "$ARCH" = "aarch64" ]; then - ARCH="arm64" - fi - mkdir -p "$cachedir" - # When running on GitHub in CI, the below curl sometimes fails with - # INTERNAL_ERROR after finishing the download. The most common cause - # of INTERNAL_ERROR is glitches in intermediate hosts handling of - # HTTP/2 forwarding, so forcing HTTP 1.1 often fixes the issue. See - # https://github.com/tailscale/tailscale/issues/8988 - curl -f -L --http1.1 -o "$tarball" "https://nodejs.org/dist/v${want_rev}/node-v${want_rev}-${OS}-${ARCH}.tar.gz" - (cd "$cachedir" && tar --strip-components=1 -xf "$tarball") - rm -f "$tarball" - fi - fi -) - -export PATH="$HOME/.cache/tailscale-node/bin:$PATH" -exec "$HOME/.cache/tailscale-node/bin/node" "$@" +#!/usr/bin/env bash +# Run a command with our local node install, rather than any globally installed +# instance. + +set -euo pipefail + +if [[ "${CI:-}" == "true" ]]; then + set -x +fi + +( + if [[ "${CI:-}" == "true" ]]; then + set -x + fi + + repo_root="${BASH_SOURCE%/*}/../" + cd "$repo_root" + + cachedir="$HOME/.cache/tailscale-node" + tarball="${cachedir}.tar.gz" + + read -r want_rev < "$(dirname "$0")/node.rev" + + got_rev="" + if [[ -x "${cachedir}/bin/node" ]]; then + got_rev=$("${cachedir}/bin/node" --version) + got_rev="${got_rev#v}" # trim the leading 'v' + fi + + if [[ "$want_rev" != "$got_rev" ]]; then + rm -rf "$cachedir" "$tarball" + if [[ -n "${IN_NIX_SHELL:-}" ]]; then + nix_node="$(which -a node | grep /nix/store | head -1)" + nix_node="${nix_node%/bin/node}" + nix_node_rev="${nix_node##*-}" + if [[ "$nix_node_rev" != "$want_rev" ]]; then + echo "Wrong node version in Nix, got $nix_node_rev want $want_rev" >&2 + exit 1 + fi + ln -sf "$nix_node" "$cachedir" + else + # works for "linux" and "darwin" + OS=$(uname -s | tr A-Z a-z) + ARCH=$(uname -m) + if [ "$ARCH" = "x86_64" ]; then + ARCH="x64" + fi + if [ "$ARCH" = "aarch64" ]; then + ARCH="arm64" + fi + mkdir -p "$cachedir" + # When running on GitHub in CI, the below curl sometimes fails with + # INTERNAL_ERROR after finishing the download. The most common cause + # of INTERNAL_ERROR is glitches in intermediate hosts handling of + # HTTP/2 forwarding, so forcing HTTP 1.1 often fixes the issue. See + # https://github.com/tailscale/tailscale/issues/8988 + curl -f -L --http1.1 -o "$tarball" "https://nodejs.org/dist/v${want_rev}/node-v${want_rev}-${OS}-${ARCH}.tar.gz" + (cd "$cachedir" && tar --strip-components=1 -xf "$tarball") + rm -f "$tarball" + fi + fi +) + +export PATH="$HOME/.cache/tailscale-node/bin:$PATH" +exec "$HOME/.cache/tailscale-node/bin/node" "$@" diff --git a/tool/wasm-opt b/tool/wasm-opt index 88d332f0b..08f3e5bfb 100755 --- a/tool/wasm-opt +++ b/tool/wasm-opt @@ -1,74 +1,74 @@ -#!/bin/sh -# -# This script acts like the "wasm-opt" command from the Binaryen toolchain, but -# uses Tailscale's currently-desired version, downloading it first if necessary. - -set -eu - -BINARYEN_DIR="$HOME/.cache/tailscale-binaryen" -read -r BINARYEN_REV < "$(dirname "$0")/binaryen.rev" -# This works for Linux and Darwin, which is sufficient -# (we do not build for other targets). -OS=$(uname -s | tr A-Z a-z) -if [ "$OS" = "darwin" ]; then - # Binaryen uses the name "macos". - OS="macos" -fi -ARCH="$(uname -m)" -if [ "$ARCH" = "aarch64" ]; then - # Binaryen uses the name "arm64". - ARCH="arm64" -fi - -install_binaryen() { - BINARYEN_URL="https://github.com/WebAssembly/binaryen/releases/download/version_${BINARYEN_REV}/binaryen-version_${BINARYEN_REV}-${ARCH}-${OS}.tar.gz" - install_tool "wasm-opt" $BINARYEN_REV $BINARYEN_DIR $BINARYEN_URL -} - -install_tool() { - TOOL=$1 - REV=$2 - TOOLCHAIN=$3 - URL=$4 - - archive="$TOOLCHAIN-$REV.tar.gz" - mark="$TOOLCHAIN.extracted" - extracted= - [ ! -e "$mark" ] || read -r extracted junk <$mark - - if [ "$extracted" = "$REV" ] && [ -e "$TOOLCHAIN/bin/$TOOL" ]; then - # Already extracted, continue silently - return 0 - fi - echo "" - - rm -f "$archive.new" "$TOOLCHAIN.extracted" - if [ ! -e "$archive" ]; then - log "Need to download $TOOL '$REV' from $URL." - curl -f -L -o "$archive.new" $URL - rm -f "$archive" - mv "$archive.new" "$archive" - fi - - log "Extracting $TOOL '$REV' into '$TOOLCHAIN'." >&2 - rm -rf "$TOOLCHAIN" - mkdir -p "$TOOLCHAIN" - (cd "$TOOLCHAIN" && tar --strip-components=1 -xf "$archive") - echo "$REV" >$mark -} - -log() { - echo "$@" >&2 -} - -if [ "${BINARYEN_DIR}" = "SKIP" ] || - [ "${OS}" != "macos" -a "${OS}" != "linux" ] || - [ "${ARCH}" != "x86_64" -a "${ARCH}" != "arm64" ]; then - log "Unsupported OS (${OS}) and architecture (${ARCH}) combination." - log "Using existing wasm-opt (`which wasm-opt`)." - exec wasm-opt "$@" -fi - -install_binaryen - -"$BINARYEN_DIR/bin/wasm-opt" "$@" +#!/bin/sh +# +# This script acts like the "wasm-opt" command from the Binaryen toolchain, but +# uses Tailscale's currently-desired version, downloading it first if necessary. + +set -eu + +BINARYEN_DIR="$HOME/.cache/tailscale-binaryen" +read -r BINARYEN_REV < "$(dirname "$0")/binaryen.rev" +# This works for Linux and Darwin, which is sufficient +# (we do not build for other targets). +OS=$(uname -s | tr A-Z a-z) +if [ "$OS" = "darwin" ]; then + # Binaryen uses the name "macos". + OS="macos" +fi +ARCH="$(uname -m)" +if [ "$ARCH" = "aarch64" ]; then + # Binaryen uses the name "arm64". + ARCH="arm64" +fi + +install_binaryen() { + BINARYEN_URL="https://github.com/WebAssembly/binaryen/releases/download/version_${BINARYEN_REV}/binaryen-version_${BINARYEN_REV}-${ARCH}-${OS}.tar.gz" + install_tool "wasm-opt" $BINARYEN_REV $BINARYEN_DIR $BINARYEN_URL +} + +install_tool() { + TOOL=$1 + REV=$2 + TOOLCHAIN=$3 + URL=$4 + + archive="$TOOLCHAIN-$REV.tar.gz" + mark="$TOOLCHAIN.extracted" + extracted= + [ ! -e "$mark" ] || read -r extracted junk <$mark + + if [ "$extracted" = "$REV" ] && [ -e "$TOOLCHAIN/bin/$TOOL" ]; then + # Already extracted, continue silently + return 0 + fi + echo "" + + rm -f "$archive.new" "$TOOLCHAIN.extracted" + if [ ! -e "$archive" ]; then + log "Need to download $TOOL '$REV' from $URL." + curl -f -L -o "$archive.new" $URL + rm -f "$archive" + mv "$archive.new" "$archive" + fi + + log "Extracting $TOOL '$REV' into '$TOOLCHAIN'." >&2 + rm -rf "$TOOLCHAIN" + mkdir -p "$TOOLCHAIN" + (cd "$TOOLCHAIN" && tar --strip-components=1 -xf "$archive") + echo "$REV" >$mark +} + +log() { + echo "$@" >&2 +} + +if [ "${BINARYEN_DIR}" = "SKIP" ] || + [ "${OS}" != "macos" -a "${OS}" != "linux" ] || + [ "${ARCH}" != "x86_64" -a "${ARCH}" != "arm64" ]; then + log "Unsupported OS (${OS}) and architecture (${ARCH}) combination." + log "Using existing wasm-opt (`which wasm-opt`)." + exec wasm-opt "$@" +fi + +install_binaryen + +"$BINARYEN_DIR/bin/wasm-opt" "$@" diff --git a/tool/yarn b/tool/yarn index 6bb01d2f2..6357beda6 100755 --- a/tool/yarn +++ b/tool/yarn @@ -1,43 +1,43 @@ -#!/usr/bin/env bash -# Run a command with our local yarn install, rather than any globally installed -# instance. - -set -euo pipefail - -if [[ "${CI:-}" == "true" ]]; then - set -x -fi - -( - if [[ "${CI:-}" == "true" ]]; then - set -x - fi - - repo_root="${BASH_SOURCE%/*}/../" - cd "$repo_root" - - ./tool/node --version >/dev/null # Ensure node is unpacked and ready - - cachedir="$HOME/.cache/tailscale-yarn" - tarball="${cachedir}.tar.gz" - - read -r want_rev < "./tool/yarn.rev" - - got_rev="" - if [[ -x "${cachedir}/bin/yarn" ]]; then - got_rev=$(PATH="$HOME/.cache/tailscale-node/bin:$PATH" "${cachedir}/bin/yarn" --version) - fi - - if [[ "$want_rev" != "$got_rev" ]]; then - rm -rf "$cachedir" "$tarball" - mkdir -p "$cachedir" - curl -f -L -o "$tarball" "https://github.com/yarnpkg/yarn/releases/download/v${want_rev}/yarn-v${want_rev}.tar.gz" - (cd "$cachedir" && tar --strip-components=1 -xf "$tarball") - rm -f "$tarball" - fi -) - -# Deliberately not using cachedir here, to keep the environment -# completely pristine for execution of yarn. -export PATH="$HOME/.cache/tailscale-node/bin:$HOME/.cache/tailscale-yarn/bin:$PATH" -exec "$HOME/.cache/tailscale-yarn/bin/yarn" "$@" +#!/usr/bin/env bash +# Run a command with our local yarn install, rather than any globally installed +# instance. + +set -euo pipefail + +if [[ "${CI:-}" == "true" ]]; then + set -x +fi + +( + if [[ "${CI:-}" == "true" ]]; then + set -x + fi + + repo_root="${BASH_SOURCE%/*}/../" + cd "$repo_root" + + ./tool/node --version >/dev/null # Ensure node is unpacked and ready + + cachedir="$HOME/.cache/tailscale-yarn" + tarball="${cachedir}.tar.gz" + + read -r want_rev < "./tool/yarn.rev" + + got_rev="" + if [[ -x "${cachedir}/bin/yarn" ]]; then + got_rev=$(PATH="$HOME/.cache/tailscale-node/bin:$PATH" "${cachedir}/bin/yarn" --version) + fi + + if [[ "$want_rev" != "$got_rev" ]]; then + rm -rf "$cachedir" "$tarball" + mkdir -p "$cachedir" + curl -f -L -o "$tarball" "https://github.com/yarnpkg/yarn/releases/download/v${want_rev}/yarn-v${want_rev}.tar.gz" + (cd "$cachedir" && tar --strip-components=1 -xf "$tarball") + rm -f "$tarball" + fi +) + +# Deliberately not using cachedir here, to keep the environment +# completely pristine for execution of yarn. +export PATH="$HOME/.cache/tailscale-node/bin:$HOME/.cache/tailscale-yarn/bin:$PATH" +exec "$HOME/.cache/tailscale-yarn/bin/yarn" "$@" diff --git a/tool/yarn.rev b/tool/yarn.rev index 736c4acbd..de5856e86 100644 --- a/tool/yarn.rev +++ b/tool/yarn.rev @@ -1 +1 @@ -1.22.19 +1.22.19 diff --git a/tsnet/example/tshello/tshello.go b/tsnet/example/tshello/tshello.go index 2110c4d96..0cadcdd83 100644 --- a/tsnet/example/tshello/tshello.go +++ b/tsnet/example/tshello/tshello.go @@ -1,60 +1,60 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The tshello server demonstrates how to use Tailscale as a library. -package main - -import ( - "crypto/tls" - "flag" - "fmt" - "html" - "log" - "net/http" - "strings" - - "tailscale.com/tsnet" -) - -var ( - addr = flag.String("addr", ":80", "address to listen on") -) - -func main() { - flag.Parse() - s := new(tsnet.Server) - defer s.Close() - ln, err := s.Listen("tcp", *addr) - if err != nil { - log.Fatal(err) - } - defer ln.Close() - - lc, err := s.LocalClient() - if err != nil { - log.Fatal(err) - } - - if *addr == ":443" { - ln = tls.NewListener(ln, &tls.Config{ - GetCertificate: lc.GetCertificate, - }) - } - log.Fatal(http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - who, err := lc.WhoIs(r.Context(), r.RemoteAddr) - if err != nil { - http.Error(w, err.Error(), 500) - return - } - fmt.Fprintf(w, "

Hello, world!

\n") - fmt.Fprintf(w, "

You are %s from %s (%s)

", - html.EscapeString(who.UserProfile.LoginName), - html.EscapeString(firstLabel(who.Node.ComputedName)), - r.RemoteAddr) - }))) -} - -func firstLabel(s string) string { - s, _, _ = strings.Cut(s, ".") - return s -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The tshello server demonstrates how to use Tailscale as a library. +package main + +import ( + "crypto/tls" + "flag" + "fmt" + "html" + "log" + "net/http" + "strings" + + "tailscale.com/tsnet" +) + +var ( + addr = flag.String("addr", ":80", "address to listen on") +) + +func main() { + flag.Parse() + s := new(tsnet.Server) + defer s.Close() + ln, err := s.Listen("tcp", *addr) + if err != nil { + log.Fatal(err) + } + defer ln.Close() + + lc, err := s.LocalClient() + if err != nil { + log.Fatal(err) + } + + if *addr == ":443" { + ln = tls.NewListener(ln, &tls.Config{ + GetCertificate: lc.GetCertificate, + }) + } + log.Fatal(http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + who, err := lc.WhoIs(r.Context(), r.RemoteAddr) + if err != nil { + http.Error(w, err.Error(), 500) + return + } + fmt.Fprintf(w, "

Hello, world!

\n") + fmt.Fprintf(w, "

You are %s from %s (%s)

", + html.EscapeString(who.UserProfile.LoginName), + html.EscapeString(firstLabel(who.Node.ComputedName)), + r.RemoteAddr) + }))) +} + +func firstLabel(s string) string { + s, _, _ = strings.Cut(s, ".") + return s +} diff --git a/tsnet/example/tsnet-http-client/tsnet-http-client.go b/tsnet/example/tsnet-http-client/tsnet-http-client.go index cda52eef7..9666fe999 100644 --- a/tsnet/example/tsnet-http-client/tsnet-http-client.go +++ b/tsnet/example/tsnet-http-client/tsnet-http-client.go @@ -1,44 +1,44 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The tshello server demonstrates how to use Tailscale as a library. -package main - -import ( - "flag" - "fmt" - "log" - "os" - "path/filepath" - - "tailscale.com/tsnet" -) - -func main() { - flag.Usage = func() { - fmt.Fprintf(os.Stderr, "Usage: %s \n", filepath.Base(os.Args[0])) - os.Exit(2) - } - flag.Parse() - - if flag.NArg() != 1 { - flag.Usage() - } - tailnetURL := flag.Arg(0) - - s := new(tsnet.Server) - defer s.Close() - - if err := s.Start(); err != nil { - log.Fatal(err) - } - - cli := s.HTTPClient() - - resp, err := cli.Get(tailnetURL) - if err != nil { - log.Fatal(err) - } - - resp.Write(os.Stdout) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The tshello server demonstrates how to use Tailscale as a library. +package main + +import ( + "flag" + "fmt" + "log" + "os" + "path/filepath" + + "tailscale.com/tsnet" +) + +func main() { + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage: %s \n", filepath.Base(os.Args[0])) + os.Exit(2) + } + flag.Parse() + + if flag.NArg() != 1 { + flag.Usage() + } + tailnetURL := flag.Arg(0) + + s := new(tsnet.Server) + defer s.Close() + + if err := s.Start(); err != nil { + log.Fatal(err) + } + + cli := s.HTTPClient() + + resp, err := cli.Get(tailnetURL) + if err != nil { + log.Fatal(err) + } + + resp.Write(os.Stdout) +} diff --git a/tsnet/example/web-client/web-client.go b/tsnet/example/web-client/web-client.go index dee7fedfa..541efbaed 100644 --- a/tsnet/example/web-client/web-client.go +++ b/tsnet/example/web-client/web-client.go @@ -1,46 +1,46 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The web-client command demonstrates serving the Tailscale web client over tsnet. -package main - -import ( - "flag" - "log" - "net/http" - - "tailscale.com/client/web" - "tailscale.com/tsnet" -) - -var ( - addr = flag.String("addr", "localhost:8060", "address of Tailscale web client") -) - -func main() { - flag.Parse() - - s := &tsnet.Server{RunWebClient: true} - defer s.Close() - - lc, err := s.LocalClient() - if err != nil { - log.Fatal(err) - } - - // Serve the Tailscale web client. - ws, err := web.NewServer(web.ServerOpts{ - Mode: web.LoginServerMode, - LocalClient: lc, - }) - if err != nil { - log.Fatal(err) - } - defer ws.Shutdown() - log.Printf("Serving Tailscale web client on http://%s", *addr) - if err := http.ListenAndServe(*addr, ws); err != nil { - if err != http.ErrServerClosed { - log.Fatal(err) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The web-client command demonstrates serving the Tailscale web client over tsnet. +package main + +import ( + "flag" + "log" + "net/http" + + "tailscale.com/client/web" + "tailscale.com/tsnet" +) + +var ( + addr = flag.String("addr", "localhost:8060", "address of Tailscale web client") +) + +func main() { + flag.Parse() + + s := &tsnet.Server{RunWebClient: true} + defer s.Close() + + lc, err := s.LocalClient() + if err != nil { + log.Fatal(err) + } + + // Serve the Tailscale web client. + ws, err := web.NewServer(web.ServerOpts{ + Mode: web.LoginServerMode, + LocalClient: lc, + }) + if err != nil { + log.Fatal(err) + } + defer ws.Shutdown() + log.Printf("Serving Tailscale web client on http://%s", *addr) + if err := http.ListenAndServe(*addr, ws); err != nil { + if err != http.ErrServerClosed { + log.Fatal(err) + } + } +} diff --git a/tsnet/example_tshello_test.go b/tsnet/example_tshello_test.go index 4dec48233..d534bcfd1 100644 --- a/tsnet/example_tshello_test.go +++ b/tsnet/example_tshello_test.go @@ -1,72 +1,72 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tsnet_test - -import ( - "flag" - "fmt" - "html" - "log" - "net/http" - "strings" - - "tailscale.com/tsnet" -) - -func firstLabel(s string) string { - s, _, _ = strings.Cut(s, ".") - return s -} - -// Example_tshello is a full example on using tsnet. When you run this program it will print -// an authentication link. Open it in your favorite web browser and add it to your tailnet -// like any other machine. Open another terminal window and try to ping it: -// -// $ ping tshello -c 2 -// PING tshello (100.105.183.159) 56(84) bytes of data. -// 64 bytes from tshello.your-tailnet.ts.net (100.105.183.159): icmp_seq=1 ttl=64 time=25.0 ms -// 64 bytes from tshello.your-tailnet.ts.net (100.105.183.159): icmp_seq=2 ttl=64 time=1.12 ms -// -// Then connect to it using curl: -// -// $ curl http://tshello -//

Hello, world!

-//

You are Xe from pneuma (100.78.40.86:49214)

-// -// From here you can do anything you want with the Go standard library HTTP stack, or anything -// that is compatible with it (Gin/Gonic, Gorilla/mux, etc.). -func Example_tshello() { - var ( - addr = flag.String("addr", ":80", "address to listen on") - hostname = flag.String("hostname", "tshello", "hostname to use on the tailnet") - ) - - flag.Parse() - s := new(tsnet.Server) - s.Hostname = *hostname - defer s.Close() - ln, err := s.Listen("tcp", *addr) - if err != nil { - log.Fatal(err) - } - defer ln.Close() - - lc, err := s.LocalClient() - if err != nil { - log.Fatal(err) - } - - log.Fatal(http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - who, err := lc.WhoIs(r.Context(), r.RemoteAddr) - if err != nil { - http.Error(w, err.Error(), 500) - return - } - fmt.Fprintf(w, "

Hello, tailnet!

\n") - fmt.Fprintf(w, "

You are %s from %s (%s)

", - html.EscapeString(who.UserProfile.LoginName), - html.EscapeString(firstLabel(who.Node.ComputedName)), - r.RemoteAddr) - }))) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tsnet_test + +import ( + "flag" + "fmt" + "html" + "log" + "net/http" + "strings" + + "tailscale.com/tsnet" +) + +func firstLabel(s string) string { + s, _, _ = strings.Cut(s, ".") + return s +} + +// Example_tshello is a full example on using tsnet. When you run this program it will print +// an authentication link. Open it in your favorite web browser and add it to your tailnet +// like any other machine. Open another terminal window and try to ping it: +// +// $ ping tshello -c 2 +// PING tshello (100.105.183.159) 56(84) bytes of data. +// 64 bytes from tshello.your-tailnet.ts.net (100.105.183.159): icmp_seq=1 ttl=64 time=25.0 ms +// 64 bytes from tshello.your-tailnet.ts.net (100.105.183.159): icmp_seq=2 ttl=64 time=1.12 ms +// +// Then connect to it using curl: +// +// $ curl http://tshello +//

Hello, world!

+//

You are Xe from pneuma (100.78.40.86:49214)

+// +// From here you can do anything you want with the Go standard library HTTP stack, or anything +// that is compatible with it (Gin/Gonic, Gorilla/mux, etc.). +func Example_tshello() { + var ( + addr = flag.String("addr", ":80", "address to listen on") + hostname = flag.String("hostname", "tshello", "hostname to use on the tailnet") + ) + + flag.Parse() + s := new(tsnet.Server) + s.Hostname = *hostname + defer s.Close() + ln, err := s.Listen("tcp", *addr) + if err != nil { + log.Fatal(err) + } + defer ln.Close() + + lc, err := s.LocalClient() + if err != nil { + log.Fatal(err) + } + + log.Fatal(http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + who, err := lc.WhoIs(r.Context(), r.RemoteAddr) + if err != nil { + http.Error(w, err.Error(), 500) + return + } + fmt.Fprintf(w, "

Hello, tailnet!

\n") + fmt.Fprintf(w, "

You are %s from %s (%s)

", + html.EscapeString(who.UserProfile.LoginName), + html.EscapeString(firstLabel(who.Node.ComputedName)), + r.RemoteAddr) + }))) +} diff --git a/tstest/allocs.go b/tstest/allocs.go index a6d9c79f6..f15a00508 100644 --- a/tstest/allocs.go +++ b/tstest/allocs.go @@ -1,50 +1,50 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstest - -import ( - "fmt" - "runtime" - "testing" - "time" -) - -// MinAllocsPerRun asserts that f can run with no more than target allocations. -// It runs f up to 1000 times or 5s, whichever happens first. -// If f has executed more than target allocations on every run, it returns a non-nil error. -// -// MinAllocsPerRun sets GOMAXPROCS to 1 during its measurement and restores -// it before returning. -func MinAllocsPerRun(t *testing.T, target uint64, f func()) error { - defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1)) - - var memstats runtime.MemStats - var min, max, sum uint64 - start := time.Now() - var iters int - for { - runtime.ReadMemStats(&memstats) - startMallocs := memstats.Mallocs - f() - runtime.ReadMemStats(&memstats) - mallocs := memstats.Mallocs - startMallocs - // TODO: if mallocs < target, return an error? See discussion in #3204. - if mallocs <= target { - return nil - } - if min == 0 || mallocs < min { - min = mallocs - } - if mallocs > max { - max = mallocs - } - sum += mallocs - iters++ - if iters == 1000 || time.Since(start) > 5*time.Second { - break - } - } - - return fmt.Errorf("min allocs = %d, max allocs = %d, avg allocs/run = %f, want run with <= %d allocs", min, max, float64(sum)/float64(iters), target) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstest + +import ( + "fmt" + "runtime" + "testing" + "time" +) + +// MinAllocsPerRun asserts that f can run with no more than target allocations. +// It runs f up to 1000 times or 5s, whichever happens first. +// If f has executed more than target allocations on every run, it returns a non-nil error. +// +// MinAllocsPerRun sets GOMAXPROCS to 1 during its measurement and restores +// it before returning. +func MinAllocsPerRun(t *testing.T, target uint64, f func()) error { + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1)) + + var memstats runtime.MemStats + var min, max, sum uint64 + start := time.Now() + var iters int + for { + runtime.ReadMemStats(&memstats) + startMallocs := memstats.Mallocs + f() + runtime.ReadMemStats(&memstats) + mallocs := memstats.Mallocs - startMallocs + // TODO: if mallocs < target, return an error? See discussion in #3204. + if mallocs <= target { + return nil + } + if min == 0 || mallocs < min { + min = mallocs + } + if mallocs > max { + max = mallocs + } + sum += mallocs + iters++ + if iters == 1000 || time.Since(start) > 5*time.Second { + break + } + } + + return fmt.Errorf("min allocs = %d, max allocs = %d, avg allocs/run = %f, want run with <= %d allocs", min, max, float64(sum)/float64(iters), target) +} diff --git a/tstest/archtest/qemu_test.go b/tstest/archtest/qemu_test.go index cea3b4b8e..8b59ae5d9 100644 --- a/tstest/archtest/qemu_test.go +++ b/tstest/archtest/qemu_test.go @@ -1,73 +1,73 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux && amd64 && !race - -package archtest - -import ( - "bytes" - "fmt" - "os" - "os/exec" - "strings" - "testing" - - "tailscale.com/util/cibuild" -) - -func TestInQemu(t *testing.T) { - t.Parallel() - type Arch struct { - Goarch string // GOARCH value - Qarch string // qemu name - } - arches := []Arch{ - {"arm", "arm"}, - {"arm64", "aarch64"}, - {"mips", "mips"}, - {"mipsle", "mipsel"}, - {"mips64", "mips64"}, - {"mips64le", "mips64el"}, - {"386", "386"}, - } - inCI := cibuild.On() - for _, arch := range arches { - arch := arch - t.Run(arch.Goarch, func(t *testing.T) { - t.Parallel() - qemuUser := "qemu-" + arch.Qarch - execVia := qemuUser - if arch.Goarch == "386" { - execVia = "" // amd64 can run it fine - } else { - look, err := exec.LookPath(qemuUser) - if err != nil { - if inCI { - t.Fatalf("in CI and qemu not available: %v", err) - } - t.Skipf("%s not found; skipping test. error was: %v", qemuUser, err) - } - t.Logf("using %v", look) - } - cmd := exec.Command("go", - "test", - "--exec="+execVia, - "-v", - "tailscale.com/tstest/archtest", - ) - cmd.Env = append(os.Environ(), "GOARCH="+arch.Goarch) - out, err := cmd.CombinedOutput() - if err != nil { - if strings.Contains(string(out), "fatal error: sigaction failed") && !inCI { - t.Skip("skipping; qemu too old. use 5.x.") - } - t.Errorf("failed: %s", out) - } - sub := fmt.Sprintf("I am linux/%s", arch.Goarch) - if !bytes.Contains(out, []byte(sub)) { - t.Errorf("output didn't contain %q: %s", sub, out) - } - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux && amd64 && !race + +package archtest + +import ( + "bytes" + "fmt" + "os" + "os/exec" + "strings" + "testing" + + "tailscale.com/util/cibuild" +) + +func TestInQemu(t *testing.T) { + t.Parallel() + type Arch struct { + Goarch string // GOARCH value + Qarch string // qemu name + } + arches := []Arch{ + {"arm", "arm"}, + {"arm64", "aarch64"}, + {"mips", "mips"}, + {"mipsle", "mipsel"}, + {"mips64", "mips64"}, + {"mips64le", "mips64el"}, + {"386", "386"}, + } + inCI := cibuild.On() + for _, arch := range arches { + arch := arch + t.Run(arch.Goarch, func(t *testing.T) { + t.Parallel() + qemuUser := "qemu-" + arch.Qarch + execVia := qemuUser + if arch.Goarch == "386" { + execVia = "" // amd64 can run it fine + } else { + look, err := exec.LookPath(qemuUser) + if err != nil { + if inCI { + t.Fatalf("in CI and qemu not available: %v", err) + } + t.Skipf("%s not found; skipping test. error was: %v", qemuUser, err) + } + t.Logf("using %v", look) + } + cmd := exec.Command("go", + "test", + "--exec="+execVia, + "-v", + "tailscale.com/tstest/archtest", + ) + cmd.Env = append(os.Environ(), "GOARCH="+arch.Goarch) + out, err := cmd.CombinedOutput() + if err != nil { + if strings.Contains(string(out), "fatal error: sigaction failed") && !inCI { + t.Skip("skipping; qemu too old. use 5.x.") + } + t.Errorf("failed: %s", out) + } + sub := fmt.Sprintf("I am linux/%s", arch.Goarch) + if !bytes.Contains(out, []byte(sub)) { + t.Errorf("output didn't contain %q: %s", sub, out) + } + }) + } +} diff --git a/tstest/clock.go b/tstest/clock.go index 48684957e..ee7523430 100644 --- a/tstest/clock.go +++ b/tstest/clock.go @@ -1,694 +1,694 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstest - -import ( - "container/heap" - "sync" - "time" - - "tailscale.com/tstime" - "tailscale.com/util/mak" -) - -// ClockOpts is used to configure the initial settings for a Clock. Once the -// settings are configured as desired, call NewClock to get the resulting Clock. -type ClockOpts struct { - // Start is the starting time for the Clock. When FollowRealTime is false, - // Start is also the value that will be returned by the first call - // to Clock.Now. - Start time.Time - // Step is the amount of time the Clock will advance whenever Clock.Now is - // called. If set to zero, the Clock will only advance when Clock.Advance is - // called and/or if FollowRealTime is true. - // - // FollowRealTime and Step cannot be enabled at the same time. - Step time.Duration - - // TimerChannelSize configures the maximum buffered ticks that are - // permitted in the channel of any Timer and Ticker created by this Clock. - // The special value 0 means to use the default of 1. The buffer may need to - // be increased if time is advanced by more than a single tick and proper - // functioning of the test requires that the ticks are not lost. - TimerChannelSize int - - // FollowRealTime makes the simulated time increment along with real time. - // It is a compromise between determinism and the difficulty of explicitly - // managing the simulated time via Step or Clock.Advance. When - // FollowRealTime is set, calls to Now() and PeekNow() will add the - // elapsed real-world time to the simulated time. - // - // FollowRealTime and Step cannot be enabled at the same time. - FollowRealTime bool -} - -// NewClock creates a Clock with the specified settings. To create a -// Clock with only the default settings, new(Clock) is equivalent, except that -// the start time will not be computed until one of the receivers is called. -func NewClock(co ClockOpts) *Clock { - if co.FollowRealTime && co.Step != 0 { - panic("only one of FollowRealTime and Step are allowed in NewClock") - } - - return newClockInternal(co, nil) -} - -// newClockInternal creates a Clock with the specified settings and allows -// specifying a non-standard realTimeClock. -func newClockInternal(co ClockOpts, rtClock tstime.Clock) *Clock { - if !co.FollowRealTime && rtClock != nil { - panic("rtClock can only be set with FollowRealTime enabled") - } - - if co.FollowRealTime && rtClock == nil { - rtClock = new(tstime.StdClock) - } - - c := &Clock{ - start: co.Start, - realTimeClock: rtClock, - step: co.Step, - timerChannelSize: co.TimerChannelSize, - } - c.init() // init now to capture the current time when co.Start.IsZero() - return c -} - -// Clock is a testing clock that advances every time its Now method is -// called, beginning at its start time. If no start time is specified using -// ClockBuilder, an arbitrary start time will be selected when the Clock is -// created and can be retrieved by calling Clock.Start(). -type Clock struct { - // start is the first value returned by Now. It must not be modified after - // init is called. - start time.Time - - // realTimeClock, if not nil, indicates that the Clock shall move forward - // according to realTimeClock + the accumulated calls to Advance. This can - // make writing tests easier that require some control over the clock but do - // not need exact control over the clock. While step can also be used for - // this purpose, it is harder to control how quickly time moves using step. - realTimeClock tstime.Clock - - initOnce sync.Once - mu sync.Mutex - - // step is how much to advance with each Now call. - step time.Duration - // present is the last value returned by Now (and will be returned again by - // PeekNow). - present time.Time - // realTime is the time from realTimeClock corresponding to the current - // value of present. - realTime time.Time - // skipStep indicates that the next call to Now should not add step to - // present. This occurs after initialization and after Advance. - skipStep bool - // timerChannelSize is the buffer size to use for channels created by - // NewTimer and NewTicker. - timerChannelSize int - - events eventManager -} - -func (c *Clock) init() { - c.initOnce.Do(func() { - if c.realTimeClock != nil { - c.realTime = c.realTimeClock.Now() - } - if c.start.IsZero() { - if c.realTime.IsZero() { - c.start = time.Now() - } else { - c.start = c.realTime - } - } - if c.timerChannelSize == 0 { - c.timerChannelSize = 1 - } - c.present = c.start - c.skipStep = true - c.events.AdvanceTo(c.present) - }) -} - -// Now returns the virtual clock's current time, and advances it -// according to its step configuration. -func (c *Clock) Now() time.Time { - c.init() - rt := c.maybeGetRealTime() - - c.mu.Lock() - defer c.mu.Unlock() - - step := c.step - if c.skipStep { - step = 0 - c.skipStep = false - } - c.advanceLocked(rt, step) - - return c.present -} - -func (c *Clock) maybeGetRealTime() time.Time { - if c.realTimeClock == nil { - return time.Time{} - } - return c.realTimeClock.Now() -} - -func (c *Clock) advanceLocked(now time.Time, add time.Duration) { - if !now.IsZero() { - add += now.Sub(c.realTime) - c.realTime = now - } - if add == 0 { - return - } - c.present = c.present.Add(add) - c.events.AdvanceTo(c.present) -} - -// PeekNow returns the last time reported by Now. If Now has never been called, -// PeekNow returns the same value as GetStart. -func (c *Clock) PeekNow() time.Time { - c.init() - c.mu.Lock() - defer c.mu.Unlock() - return c.present -} - -// Advance moves simulated time forward or backwards by a relative amount. Any -// Timer or Ticker that is waiting will fire at the requested point in simulated -// time. Advance returns the new simulated time. If this Clock follows real time -// then the next call to Now will equal the return value of Advance + the -// elapsed time since calling Advance. Otherwise, the next call to Now will -// equal the return value of Advance, regardless of the current step. -func (c *Clock) Advance(d time.Duration) time.Time { - c.init() - rt := c.maybeGetRealTime() - - c.mu.Lock() - defer c.mu.Unlock() - c.skipStep = true - - c.advanceLocked(rt, d) - return c.present -} - -// AdvanceTo moves simulated time to a new absolute value. Any Timer or Ticker -// that is waiting will fire at the requested point in simulated time. If this -// Clock follows real time then the next call to Now will equal t + the elapsed -// time since calling Advance. Otherwise, the next call to Now will equal t, -// regardless of the configured step. -func (c *Clock) AdvanceTo(t time.Time) { - c.init() - rt := c.maybeGetRealTime() - - c.mu.Lock() - defer c.mu.Unlock() - c.skipStep = true - c.realTime = rt - c.present = t - c.events.AdvanceTo(c.present) -} - -// GetStart returns the initial simulated time when this Clock was created. -func (c *Clock) GetStart() time.Time { - c.init() - c.mu.Lock() - defer c.mu.Unlock() - return c.start -} - -// GetStep returns the amount that simulated time advances on every call to Now. -func (c *Clock) GetStep() time.Duration { - c.init() - c.mu.Lock() - defer c.mu.Unlock() - return c.step -} - -// SetStep updates the amount that simulated time advances on every call to Now. -func (c *Clock) SetStep(d time.Duration) { - c.init() - c.mu.Lock() - defer c.mu.Unlock() - c.step = d -} - -// SetTimerChannelSize changes the channel size for any Timer or Ticker created -// in the future. It does not affect those that were already created. -func (c *Clock) SetTimerChannelSize(n int) { - c.init() - c.mu.Lock() - defer c.mu.Unlock() - c.timerChannelSize = n -} - -// NewTicker returns a Ticker that uses this Clock for accessing the current -// time. -func (c *Clock) NewTicker(d time.Duration) (tstime.TickerController, <-chan time.Time) { - c.init() - rt := c.maybeGetRealTime() - - c.mu.Lock() - defer c.mu.Unlock() - - c.advanceLocked(rt, 0) - t := &Ticker{ - nextTrigger: c.present.Add(d), - period: d, - em: &c.events, - } - t.init(c.timerChannelSize) - return t, t.C -} - -// NewTimer returns a Timer that uses this Clock for accessing the current -// time. -func (c *Clock) NewTimer(d time.Duration) (tstime.TimerController, <-chan time.Time) { - c.init() - rt := c.maybeGetRealTime() - - c.mu.Lock() - defer c.mu.Unlock() - - c.advanceLocked(rt, 0) - t := &Timer{ - nextTrigger: c.present.Add(d), - em: &c.events, - } - t.init(c.timerChannelSize, nil) - return t, t.C -} - -// AfterFunc returns a Timer that calls f when it fires, using this Clock for -// accessing the current time. -func (c *Clock) AfterFunc(d time.Duration, f func()) tstime.TimerController { - c.init() - rt := c.maybeGetRealTime() - - c.mu.Lock() - defer c.mu.Unlock() - - c.advanceLocked(rt, 0) - t := &Timer{ - nextTrigger: c.present.Add(d), - em: &c.events, - } - t.init(c.timerChannelSize, f) - return t -} - -// Since subtracts specified duration from Now(). -func (c *Clock) Since(t time.Time) time.Duration { - return c.Now().Sub(t) -} - -// eventHandler offers a common interface for Timer and Ticker events to avoid -// code duplication in eventManager. -type eventHandler interface { - // Fire signals the event. The provided time is written to the event's - // channel as the current time. The return value is the next time this event - // should fire, otherwise if it is zero then the event will be removed from - // the eventManager. - Fire(time.Time) time.Time -} - -// event tracks details about an upcoming Timer or Ticker firing. -type event struct { - position int // The current index in the heap, needed for heap.Fix and heap.Remove. - when time.Time // A cache of the next time the event triggers to avoid locking issues if we were to get it from eh. - eh eventHandler -} - -// eventManager tracks pending events created by Timer and Ticker. eventManager -// implements heap.Interface for efficient lookups of the next event. -type eventManager struct { - // clock is a real time clock for scheduling events with. When clock is nil, - // events only fire when AdvanceTo is called by the simulated clock that - // this eventManager belongs to. When clock is not nil, events may fire when - // timer triggers. - clock tstime.Clock - - mu sync.Mutex - now time.Time - heap []*event - reverseLookup map[eventHandler]*event - - // timer is an AfterFunc that triggers at heap[0].when.Sub(now) relative to - // the time represented by clock. In other words, if clock is real world - // time, then if an event is scheduled 1 second into the future in the - // simulated time, then the event will trigger after 1 second of actual test - // execution time (unless the test advances simulated time, in which case - // the timer is updated accordingly). This makes tests easier to write in - // situations where the simulated time only needs to be partially - // controlled, and the test writer wishes for simulated time to pass with an - // offset but still synchronized with the real world. - // - // In the future, this could be extended to allow simulated time to run at a - // multiple of real world time. - timer tstime.TimerController -} - -func (em *eventManager) handleTimer() { - rt := em.clock.Now() - em.AdvanceTo(rt) -} - -// Push implements heap.Interface.Push and must only be called by heap funcs -// with em.mu already held. -func (em *eventManager) Push(x any) { - e, ok := x.(*event) - if !ok { - panic("incorrect event type") - } - if e == nil { - panic("nil event") - } - - mak.Set(&em.reverseLookup, e.eh, e) - e.position = len(em.heap) - em.heap = append(em.heap, e) -} - -// Pop implements heap.Interface.Pop and must only be called by heap funcs with -// em.mu already held. -func (em *eventManager) Pop() any { - e := em.heap[len(em.heap)-1] - em.heap = em.heap[:len(em.heap)-1] - delete(em.reverseLookup, e.eh) - return e -} - -// Len implements sort.Interface.Len and must only be called by heap funcs with -// em.mu already held. -func (em *eventManager) Len() int { - return len(em.heap) -} - -// Less implements sort.Interface.Less and must only be called by heap funcs -// with em.mu already held. -func (em *eventManager) Less(i, j int) bool { - return em.heap[i].when.Before(em.heap[j].when) -} - -// Swap implements sort.Interface.Swap and must only be called by heap funcs -// with em.mu already held. -func (em *eventManager) Swap(i, j int) { - em.heap[i], em.heap[j] = em.heap[j], em.heap[i] - em.heap[i].position = i - em.heap[j].position = j -} - -// Reschedule adds/updates/deletes an event in the heap, whichever -// operation is applicable (use a zero time to delete). -func (em *eventManager) Reschedule(eh eventHandler, t time.Time) { - em.mu.Lock() - defer em.mu.Unlock() - defer em.updateTimerLocked() - - e, ok := em.reverseLookup[eh] - if !ok { - if t.IsZero() { - // eh is not scheduled and also not active, so do nothing. - return - } - // eh is not scheduled but is active, so add it. - heap.Push(em, &event{ - when: t, - eh: eh, - }) - em.processEventsLocked(em.now) // This is always safe and required when !t.After(em.now). - return - } - - if t.IsZero() { - // e is scheduled but not active, so remove it. - heap.Remove(em, e.position) - return - } - - // e is scheduled and active, so update it. - e.when = t - heap.Fix(em, e.position) - em.processEventsLocked(em.now) // This is always safe and required when !t.After(em.now). -} - -// AdvanceTo updates the current time to tm and fires all events scheduled -// before or equal to tm. When an event fires, it may request rescheduling and -// the rescheduled events will be combined with the other existing events that -// are waiting, and will be run in the unified ordering. A poorly behaved event -// may theoretically prevent this from ever completing, but both Timer and -// Ticker require positive steps into the future. -func (em *eventManager) AdvanceTo(tm time.Time) { - em.mu.Lock() - defer em.mu.Unlock() - defer em.updateTimerLocked() - - em.processEventsLocked(tm) - em.now = tm -} - -// Now returns the cached current time. It is intended for use by a Timer or -// Ticker that needs to convert a relative time to an absolute time. -func (em *eventManager) Now() time.Time { - em.mu.Lock() - defer em.mu.Unlock() - return em.now -} - -func (em *eventManager) processEventsLocked(tm time.Time) { - for len(em.heap) > 0 && !em.heap[0].when.After(tm) { - // Ideally some jitter would be added here but it's difficult to do so - // in a deterministic fashion. - em.now = em.heap[0].when - - if nextFire := em.heap[0].eh.Fire(em.now); !nextFire.IsZero() { - em.heap[0].when = nextFire - heap.Fix(em, 0) - } else { - heap.Pop(em) - } - } -} - -func (em *eventManager) updateTimerLocked() { - if em.clock == nil { - return - } - if len(em.heap) == 0 { - if em.timer != nil { - em.timer.Stop() - } - return - } - - timeToEvent := em.heap[0].when.Sub(em.now) - if em.timer == nil { - em.timer = em.clock.AfterFunc(timeToEvent, em.handleTimer) - return - } - em.timer.Reset(timeToEvent) -} - -// Ticker is a time.Ticker lookalike for use in tests that need to control when -// events fire. Ticker could be made standalone in future but for now is -// expected to be paired with a Clock and created by Clock.NewTicker. -type Ticker struct { - C <-chan time.Time // The channel on which ticks are delivered. - - // em is the eventManager to be notified when nextTrigger changes. - // eventManager has its own mutex, and the pointer is immutable, therefore - // em can be accessed without holding mu. - em *eventManager - - c chan<- time.Time // The writer side of C. - - mu sync.Mutex - - // nextTrigger is the time of the ticker's next scheduled activation. When - // Fire activates the ticker, nextTrigger is the timestamp written to the - // channel. - nextTrigger time.Time - - // period is the duration that is added to nextTrigger when the ticker - // fires. - period time.Duration -} - -func (t *Ticker) init(channelSize int) { - if channelSize <= 0 { - panic("ticker channel size must be non-negative") - } - c := make(chan time.Time, channelSize) - t.c = c - t.C = c - t.em.Reschedule(t, t.nextTrigger) -} - -// Fire triggers the ticker. curTime is the timestamp to write to the channel. -// The next trigger time for the ticker is updated to the last computed trigger -// time + the ticker period (set at creation or using Reset). The next trigger -// time is computed this way to match standard time.Ticker behavior, which -// prevents accumulation of long term drift caused by delays in event execution. -func (t *Ticker) Fire(curTime time.Time) time.Time { - t.mu.Lock() - defer t.mu.Unlock() - - if t.nextTrigger.IsZero() { - return time.Time{} - } - select { - case t.c <- curTime: - default: - } - t.nextTrigger = t.nextTrigger.Add(t.period) - - return t.nextTrigger -} - -// Reset adjusts the Ticker's period to d and reschedules the next fire time to -// the current simulated time + d. -func (t *Ticker) Reset(d time.Duration) { - if d <= 0 { - // The standard time.Ticker requires a positive period. - panic("non-positive period for Ticker.Reset") - } - - now := t.em.Now() - - t.mu.Lock() - t.resetLocked(now.Add(d), d) - t.mu.Unlock() - - t.em.Reschedule(t, t.nextTrigger) -} - -// ResetAbsolute adjusts the Ticker's period to d and reschedules the next fire -// time to nextTrigger. -func (t *Ticker) ResetAbsolute(nextTrigger time.Time, d time.Duration) { - if nextTrigger.IsZero() { - panic("zero nextTrigger time for ResetAbsolute") - } - if d <= 0 { - panic("non-positive period for ResetAbsolute") - } - - t.mu.Lock() - t.resetLocked(nextTrigger, d) - t.mu.Unlock() - - t.em.Reschedule(t, t.nextTrigger) -} - -func (t *Ticker) resetLocked(nextTrigger time.Time, d time.Duration) { - t.nextTrigger = nextTrigger - t.period = d -} - -// Stop deactivates the Ticker. -func (t *Ticker) Stop() { - t.mu.Lock() - t.nextTrigger = time.Time{} - t.mu.Unlock() - - t.em.Reschedule(t, t.nextTrigger) -} - -// Timer is a time.Timer lookalike for use in tests that need to control when -// events fire. Timer could be made standalone in future but for now must be -// paired with a Clock and created by Clock.NewTimer. -type Timer struct { - C <-chan time.Time // The channel on which ticks are delivered. - - // em is the eventManager to be notified when nextTrigger changes. - // eventManager has its own mutex, and the pointer is immutable, therefore - // em can be accessed without holding mu. - em *eventManager - - f func(time.Time) // The function to call when the timer expires. - - mu sync.Mutex - - // nextTrigger is the time of the ticker's next scheduled activation. When - // Fire activates the ticker, nextTrigger is the timestamp written to the - // channel. - nextTrigger time.Time -} - -func (t *Timer) init(channelSize int, afterFunc func()) { - if channelSize <= 0 { - panic("ticker channel size must be non-negative") - } - c := make(chan time.Time, channelSize) - t.C = c - if afterFunc == nil { - t.f = func(curTime time.Time) { - select { - case c <- curTime: - default: - } - } - } else { - t.f = func(_ time.Time) { afterFunc() } - } - t.em.Reschedule(t, t.nextTrigger) -} - -// Fire triggers the ticker. curTime is the timestamp to write to the channel. -// The next trigger time for the ticker is updated to the last computed trigger -// time + the ticker period (set at creation or using Reset). The next trigger -// time is computed this way to match standard time.Ticker behavior, which -// prevents accumulation of long term drift caused by delays in event execution. -func (t *Timer) Fire(curTime time.Time) time.Time { - t.mu.Lock() - defer t.mu.Unlock() - - if t.nextTrigger.IsZero() { - return time.Time{} - } - t.nextTrigger = time.Time{} - t.f(curTime) - return time.Time{} -} - -// Reset reschedules the next fire time to the current simulated time + d. -// Reset reports whether the timer was still active before the reset. -func (t *Timer) Reset(d time.Duration) bool { - if d <= 0 { - // The standard time.Timer requires a positive delay. - panic("non-positive delay for Timer.Reset") - } - - return t.reset(t.em.Now().Add(d)) -} - -// ResetAbsolute reschedules the next fire time to nextTrigger. -// ResetAbsolute reports whether the timer was still active before the reset. -func (t *Timer) ResetAbsolute(nextTrigger time.Time) bool { - if nextTrigger.IsZero() { - panic("zero nextTrigger time for ResetAbsolute") - } - - return t.reset(nextTrigger) -} - -// Stop deactivates the Timer. Stop reports whether the timer was active before -// stopping. -func (t *Timer) Stop() bool { - return t.reset(time.Time{}) -} - -func (t *Timer) reset(nextTrigger time.Time) bool { - t.mu.Lock() - wasActive := !t.nextTrigger.IsZero() - t.nextTrigger = nextTrigger - t.mu.Unlock() - - t.em.Reschedule(t, t.nextTrigger) - return wasActive -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstest + +import ( + "container/heap" + "sync" + "time" + + "tailscale.com/tstime" + "tailscale.com/util/mak" +) + +// ClockOpts is used to configure the initial settings for a Clock. Once the +// settings are configured as desired, call NewClock to get the resulting Clock. +type ClockOpts struct { + // Start is the starting time for the Clock. When FollowRealTime is false, + // Start is also the value that will be returned by the first call + // to Clock.Now. + Start time.Time + // Step is the amount of time the Clock will advance whenever Clock.Now is + // called. If set to zero, the Clock will only advance when Clock.Advance is + // called and/or if FollowRealTime is true. + // + // FollowRealTime and Step cannot be enabled at the same time. + Step time.Duration + + // TimerChannelSize configures the maximum buffered ticks that are + // permitted in the channel of any Timer and Ticker created by this Clock. + // The special value 0 means to use the default of 1. The buffer may need to + // be increased if time is advanced by more than a single tick and proper + // functioning of the test requires that the ticks are not lost. + TimerChannelSize int + + // FollowRealTime makes the simulated time increment along with real time. + // It is a compromise between determinism and the difficulty of explicitly + // managing the simulated time via Step or Clock.Advance. When + // FollowRealTime is set, calls to Now() and PeekNow() will add the + // elapsed real-world time to the simulated time. + // + // FollowRealTime and Step cannot be enabled at the same time. + FollowRealTime bool +} + +// NewClock creates a Clock with the specified settings. To create a +// Clock with only the default settings, new(Clock) is equivalent, except that +// the start time will not be computed until one of the receivers is called. +func NewClock(co ClockOpts) *Clock { + if co.FollowRealTime && co.Step != 0 { + panic("only one of FollowRealTime and Step are allowed in NewClock") + } + + return newClockInternal(co, nil) +} + +// newClockInternal creates a Clock with the specified settings and allows +// specifying a non-standard realTimeClock. +func newClockInternal(co ClockOpts, rtClock tstime.Clock) *Clock { + if !co.FollowRealTime && rtClock != nil { + panic("rtClock can only be set with FollowRealTime enabled") + } + + if co.FollowRealTime && rtClock == nil { + rtClock = new(tstime.StdClock) + } + + c := &Clock{ + start: co.Start, + realTimeClock: rtClock, + step: co.Step, + timerChannelSize: co.TimerChannelSize, + } + c.init() // init now to capture the current time when co.Start.IsZero() + return c +} + +// Clock is a testing clock that advances every time its Now method is +// called, beginning at its start time. If no start time is specified using +// ClockBuilder, an arbitrary start time will be selected when the Clock is +// created and can be retrieved by calling Clock.Start(). +type Clock struct { + // start is the first value returned by Now. It must not be modified after + // init is called. + start time.Time + + // realTimeClock, if not nil, indicates that the Clock shall move forward + // according to realTimeClock + the accumulated calls to Advance. This can + // make writing tests easier that require some control over the clock but do + // not need exact control over the clock. While step can also be used for + // this purpose, it is harder to control how quickly time moves using step. + realTimeClock tstime.Clock + + initOnce sync.Once + mu sync.Mutex + + // step is how much to advance with each Now call. + step time.Duration + // present is the last value returned by Now (and will be returned again by + // PeekNow). + present time.Time + // realTime is the time from realTimeClock corresponding to the current + // value of present. + realTime time.Time + // skipStep indicates that the next call to Now should not add step to + // present. This occurs after initialization and after Advance. + skipStep bool + // timerChannelSize is the buffer size to use for channels created by + // NewTimer and NewTicker. + timerChannelSize int + + events eventManager +} + +func (c *Clock) init() { + c.initOnce.Do(func() { + if c.realTimeClock != nil { + c.realTime = c.realTimeClock.Now() + } + if c.start.IsZero() { + if c.realTime.IsZero() { + c.start = time.Now() + } else { + c.start = c.realTime + } + } + if c.timerChannelSize == 0 { + c.timerChannelSize = 1 + } + c.present = c.start + c.skipStep = true + c.events.AdvanceTo(c.present) + }) +} + +// Now returns the virtual clock's current time, and advances it +// according to its step configuration. +func (c *Clock) Now() time.Time { + c.init() + rt := c.maybeGetRealTime() + + c.mu.Lock() + defer c.mu.Unlock() + + step := c.step + if c.skipStep { + step = 0 + c.skipStep = false + } + c.advanceLocked(rt, step) + + return c.present +} + +func (c *Clock) maybeGetRealTime() time.Time { + if c.realTimeClock == nil { + return time.Time{} + } + return c.realTimeClock.Now() +} + +func (c *Clock) advanceLocked(now time.Time, add time.Duration) { + if !now.IsZero() { + add += now.Sub(c.realTime) + c.realTime = now + } + if add == 0 { + return + } + c.present = c.present.Add(add) + c.events.AdvanceTo(c.present) +} + +// PeekNow returns the last time reported by Now. If Now has never been called, +// PeekNow returns the same value as GetStart. +func (c *Clock) PeekNow() time.Time { + c.init() + c.mu.Lock() + defer c.mu.Unlock() + return c.present +} + +// Advance moves simulated time forward or backwards by a relative amount. Any +// Timer or Ticker that is waiting will fire at the requested point in simulated +// time. Advance returns the new simulated time. If this Clock follows real time +// then the next call to Now will equal the return value of Advance + the +// elapsed time since calling Advance. Otherwise, the next call to Now will +// equal the return value of Advance, regardless of the current step. +func (c *Clock) Advance(d time.Duration) time.Time { + c.init() + rt := c.maybeGetRealTime() + + c.mu.Lock() + defer c.mu.Unlock() + c.skipStep = true + + c.advanceLocked(rt, d) + return c.present +} + +// AdvanceTo moves simulated time to a new absolute value. Any Timer or Ticker +// that is waiting will fire at the requested point in simulated time. If this +// Clock follows real time then the next call to Now will equal t + the elapsed +// time since calling Advance. Otherwise, the next call to Now will equal t, +// regardless of the configured step. +func (c *Clock) AdvanceTo(t time.Time) { + c.init() + rt := c.maybeGetRealTime() + + c.mu.Lock() + defer c.mu.Unlock() + c.skipStep = true + c.realTime = rt + c.present = t + c.events.AdvanceTo(c.present) +} + +// GetStart returns the initial simulated time when this Clock was created. +func (c *Clock) GetStart() time.Time { + c.init() + c.mu.Lock() + defer c.mu.Unlock() + return c.start +} + +// GetStep returns the amount that simulated time advances on every call to Now. +func (c *Clock) GetStep() time.Duration { + c.init() + c.mu.Lock() + defer c.mu.Unlock() + return c.step +} + +// SetStep updates the amount that simulated time advances on every call to Now. +func (c *Clock) SetStep(d time.Duration) { + c.init() + c.mu.Lock() + defer c.mu.Unlock() + c.step = d +} + +// SetTimerChannelSize changes the channel size for any Timer or Ticker created +// in the future. It does not affect those that were already created. +func (c *Clock) SetTimerChannelSize(n int) { + c.init() + c.mu.Lock() + defer c.mu.Unlock() + c.timerChannelSize = n +} + +// NewTicker returns a Ticker that uses this Clock for accessing the current +// time. +func (c *Clock) NewTicker(d time.Duration) (tstime.TickerController, <-chan time.Time) { + c.init() + rt := c.maybeGetRealTime() + + c.mu.Lock() + defer c.mu.Unlock() + + c.advanceLocked(rt, 0) + t := &Ticker{ + nextTrigger: c.present.Add(d), + period: d, + em: &c.events, + } + t.init(c.timerChannelSize) + return t, t.C +} + +// NewTimer returns a Timer that uses this Clock for accessing the current +// time. +func (c *Clock) NewTimer(d time.Duration) (tstime.TimerController, <-chan time.Time) { + c.init() + rt := c.maybeGetRealTime() + + c.mu.Lock() + defer c.mu.Unlock() + + c.advanceLocked(rt, 0) + t := &Timer{ + nextTrigger: c.present.Add(d), + em: &c.events, + } + t.init(c.timerChannelSize, nil) + return t, t.C +} + +// AfterFunc returns a Timer that calls f when it fires, using this Clock for +// accessing the current time. +func (c *Clock) AfterFunc(d time.Duration, f func()) tstime.TimerController { + c.init() + rt := c.maybeGetRealTime() + + c.mu.Lock() + defer c.mu.Unlock() + + c.advanceLocked(rt, 0) + t := &Timer{ + nextTrigger: c.present.Add(d), + em: &c.events, + } + t.init(c.timerChannelSize, f) + return t +} + +// Since subtracts specified duration from Now(). +func (c *Clock) Since(t time.Time) time.Duration { + return c.Now().Sub(t) +} + +// eventHandler offers a common interface for Timer and Ticker events to avoid +// code duplication in eventManager. +type eventHandler interface { + // Fire signals the event. The provided time is written to the event's + // channel as the current time. The return value is the next time this event + // should fire, otherwise if it is zero then the event will be removed from + // the eventManager. + Fire(time.Time) time.Time +} + +// event tracks details about an upcoming Timer or Ticker firing. +type event struct { + position int // The current index in the heap, needed for heap.Fix and heap.Remove. + when time.Time // A cache of the next time the event triggers to avoid locking issues if we were to get it from eh. + eh eventHandler +} + +// eventManager tracks pending events created by Timer and Ticker. eventManager +// implements heap.Interface for efficient lookups of the next event. +type eventManager struct { + // clock is a real time clock for scheduling events with. When clock is nil, + // events only fire when AdvanceTo is called by the simulated clock that + // this eventManager belongs to. When clock is not nil, events may fire when + // timer triggers. + clock tstime.Clock + + mu sync.Mutex + now time.Time + heap []*event + reverseLookup map[eventHandler]*event + + // timer is an AfterFunc that triggers at heap[0].when.Sub(now) relative to + // the time represented by clock. In other words, if clock is real world + // time, then if an event is scheduled 1 second into the future in the + // simulated time, then the event will trigger after 1 second of actual test + // execution time (unless the test advances simulated time, in which case + // the timer is updated accordingly). This makes tests easier to write in + // situations where the simulated time only needs to be partially + // controlled, and the test writer wishes for simulated time to pass with an + // offset but still synchronized with the real world. + // + // In the future, this could be extended to allow simulated time to run at a + // multiple of real world time. + timer tstime.TimerController +} + +func (em *eventManager) handleTimer() { + rt := em.clock.Now() + em.AdvanceTo(rt) +} + +// Push implements heap.Interface.Push and must only be called by heap funcs +// with em.mu already held. +func (em *eventManager) Push(x any) { + e, ok := x.(*event) + if !ok { + panic("incorrect event type") + } + if e == nil { + panic("nil event") + } + + mak.Set(&em.reverseLookup, e.eh, e) + e.position = len(em.heap) + em.heap = append(em.heap, e) +} + +// Pop implements heap.Interface.Pop and must only be called by heap funcs with +// em.mu already held. +func (em *eventManager) Pop() any { + e := em.heap[len(em.heap)-1] + em.heap = em.heap[:len(em.heap)-1] + delete(em.reverseLookup, e.eh) + return e +} + +// Len implements sort.Interface.Len and must only be called by heap funcs with +// em.mu already held. +func (em *eventManager) Len() int { + return len(em.heap) +} + +// Less implements sort.Interface.Less and must only be called by heap funcs +// with em.mu already held. +func (em *eventManager) Less(i, j int) bool { + return em.heap[i].when.Before(em.heap[j].when) +} + +// Swap implements sort.Interface.Swap and must only be called by heap funcs +// with em.mu already held. +func (em *eventManager) Swap(i, j int) { + em.heap[i], em.heap[j] = em.heap[j], em.heap[i] + em.heap[i].position = i + em.heap[j].position = j +} + +// Reschedule adds/updates/deletes an event in the heap, whichever +// operation is applicable (use a zero time to delete). +func (em *eventManager) Reschedule(eh eventHandler, t time.Time) { + em.mu.Lock() + defer em.mu.Unlock() + defer em.updateTimerLocked() + + e, ok := em.reverseLookup[eh] + if !ok { + if t.IsZero() { + // eh is not scheduled and also not active, so do nothing. + return + } + // eh is not scheduled but is active, so add it. + heap.Push(em, &event{ + when: t, + eh: eh, + }) + em.processEventsLocked(em.now) // This is always safe and required when !t.After(em.now). + return + } + + if t.IsZero() { + // e is scheduled but not active, so remove it. + heap.Remove(em, e.position) + return + } + + // e is scheduled and active, so update it. + e.when = t + heap.Fix(em, e.position) + em.processEventsLocked(em.now) // This is always safe and required when !t.After(em.now). +} + +// AdvanceTo updates the current time to tm and fires all events scheduled +// before or equal to tm. When an event fires, it may request rescheduling and +// the rescheduled events will be combined with the other existing events that +// are waiting, and will be run in the unified ordering. A poorly behaved event +// may theoretically prevent this from ever completing, but both Timer and +// Ticker require positive steps into the future. +func (em *eventManager) AdvanceTo(tm time.Time) { + em.mu.Lock() + defer em.mu.Unlock() + defer em.updateTimerLocked() + + em.processEventsLocked(tm) + em.now = tm +} + +// Now returns the cached current time. It is intended for use by a Timer or +// Ticker that needs to convert a relative time to an absolute time. +func (em *eventManager) Now() time.Time { + em.mu.Lock() + defer em.mu.Unlock() + return em.now +} + +func (em *eventManager) processEventsLocked(tm time.Time) { + for len(em.heap) > 0 && !em.heap[0].when.After(tm) { + // Ideally some jitter would be added here but it's difficult to do so + // in a deterministic fashion. + em.now = em.heap[0].when + + if nextFire := em.heap[0].eh.Fire(em.now); !nextFire.IsZero() { + em.heap[0].when = nextFire + heap.Fix(em, 0) + } else { + heap.Pop(em) + } + } +} + +func (em *eventManager) updateTimerLocked() { + if em.clock == nil { + return + } + if len(em.heap) == 0 { + if em.timer != nil { + em.timer.Stop() + } + return + } + + timeToEvent := em.heap[0].when.Sub(em.now) + if em.timer == nil { + em.timer = em.clock.AfterFunc(timeToEvent, em.handleTimer) + return + } + em.timer.Reset(timeToEvent) +} + +// Ticker is a time.Ticker lookalike for use in tests that need to control when +// events fire. Ticker could be made standalone in future but for now is +// expected to be paired with a Clock and created by Clock.NewTicker. +type Ticker struct { + C <-chan time.Time // The channel on which ticks are delivered. + + // em is the eventManager to be notified when nextTrigger changes. + // eventManager has its own mutex, and the pointer is immutable, therefore + // em can be accessed without holding mu. + em *eventManager + + c chan<- time.Time // The writer side of C. + + mu sync.Mutex + + // nextTrigger is the time of the ticker's next scheduled activation. When + // Fire activates the ticker, nextTrigger is the timestamp written to the + // channel. + nextTrigger time.Time + + // period is the duration that is added to nextTrigger when the ticker + // fires. + period time.Duration +} + +func (t *Ticker) init(channelSize int) { + if channelSize <= 0 { + panic("ticker channel size must be non-negative") + } + c := make(chan time.Time, channelSize) + t.c = c + t.C = c + t.em.Reschedule(t, t.nextTrigger) +} + +// Fire triggers the ticker. curTime is the timestamp to write to the channel. +// The next trigger time for the ticker is updated to the last computed trigger +// time + the ticker period (set at creation or using Reset). The next trigger +// time is computed this way to match standard time.Ticker behavior, which +// prevents accumulation of long term drift caused by delays in event execution. +func (t *Ticker) Fire(curTime time.Time) time.Time { + t.mu.Lock() + defer t.mu.Unlock() + + if t.nextTrigger.IsZero() { + return time.Time{} + } + select { + case t.c <- curTime: + default: + } + t.nextTrigger = t.nextTrigger.Add(t.period) + + return t.nextTrigger +} + +// Reset adjusts the Ticker's period to d and reschedules the next fire time to +// the current simulated time + d. +func (t *Ticker) Reset(d time.Duration) { + if d <= 0 { + // The standard time.Ticker requires a positive period. + panic("non-positive period for Ticker.Reset") + } + + now := t.em.Now() + + t.mu.Lock() + t.resetLocked(now.Add(d), d) + t.mu.Unlock() + + t.em.Reschedule(t, t.nextTrigger) +} + +// ResetAbsolute adjusts the Ticker's period to d and reschedules the next fire +// time to nextTrigger. +func (t *Ticker) ResetAbsolute(nextTrigger time.Time, d time.Duration) { + if nextTrigger.IsZero() { + panic("zero nextTrigger time for ResetAbsolute") + } + if d <= 0 { + panic("non-positive period for ResetAbsolute") + } + + t.mu.Lock() + t.resetLocked(nextTrigger, d) + t.mu.Unlock() + + t.em.Reschedule(t, t.nextTrigger) +} + +func (t *Ticker) resetLocked(nextTrigger time.Time, d time.Duration) { + t.nextTrigger = nextTrigger + t.period = d +} + +// Stop deactivates the Ticker. +func (t *Ticker) Stop() { + t.mu.Lock() + t.nextTrigger = time.Time{} + t.mu.Unlock() + + t.em.Reschedule(t, t.nextTrigger) +} + +// Timer is a time.Timer lookalike for use in tests that need to control when +// events fire. Timer could be made standalone in future but for now must be +// paired with a Clock and created by Clock.NewTimer. +type Timer struct { + C <-chan time.Time // The channel on which ticks are delivered. + + // em is the eventManager to be notified when nextTrigger changes. + // eventManager has its own mutex, and the pointer is immutable, therefore + // em can be accessed without holding mu. + em *eventManager + + f func(time.Time) // The function to call when the timer expires. + + mu sync.Mutex + + // nextTrigger is the time of the ticker's next scheduled activation. When + // Fire activates the ticker, nextTrigger is the timestamp written to the + // channel. + nextTrigger time.Time +} + +func (t *Timer) init(channelSize int, afterFunc func()) { + if channelSize <= 0 { + panic("ticker channel size must be non-negative") + } + c := make(chan time.Time, channelSize) + t.C = c + if afterFunc == nil { + t.f = func(curTime time.Time) { + select { + case c <- curTime: + default: + } + } + } else { + t.f = func(_ time.Time) { afterFunc() } + } + t.em.Reschedule(t, t.nextTrigger) +} + +// Fire triggers the ticker. curTime is the timestamp to write to the channel. +// The next trigger time for the ticker is updated to the last computed trigger +// time + the ticker period (set at creation or using Reset). The next trigger +// time is computed this way to match standard time.Ticker behavior, which +// prevents accumulation of long term drift caused by delays in event execution. +func (t *Timer) Fire(curTime time.Time) time.Time { + t.mu.Lock() + defer t.mu.Unlock() + + if t.nextTrigger.IsZero() { + return time.Time{} + } + t.nextTrigger = time.Time{} + t.f(curTime) + return time.Time{} +} + +// Reset reschedules the next fire time to the current simulated time + d. +// Reset reports whether the timer was still active before the reset. +func (t *Timer) Reset(d time.Duration) bool { + if d <= 0 { + // The standard time.Timer requires a positive delay. + panic("non-positive delay for Timer.Reset") + } + + return t.reset(t.em.Now().Add(d)) +} + +// ResetAbsolute reschedules the next fire time to nextTrigger. +// ResetAbsolute reports whether the timer was still active before the reset. +func (t *Timer) ResetAbsolute(nextTrigger time.Time) bool { + if nextTrigger.IsZero() { + panic("zero nextTrigger time for ResetAbsolute") + } + + return t.reset(nextTrigger) +} + +// Stop deactivates the Timer. Stop reports whether the timer was active before +// stopping. +func (t *Timer) Stop() bool { + return t.reset(time.Time{}) +} + +func (t *Timer) reset(nextTrigger time.Time) bool { + t.mu.Lock() + wasActive := !t.nextTrigger.IsZero() + t.nextTrigger = nextTrigger + t.mu.Unlock() + + t.em.Reschedule(t, t.nextTrigger) + return wasActive +} diff --git a/tstest/deptest/deptest_test.go b/tstest/deptest/deptest_test.go index 3b7b2dde9..ebafa5684 100644 --- a/tstest/deptest/deptest_test.go +++ b/tstest/deptest/deptest_test.go @@ -1,10 +1,10 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package deptest - -import "testing" - -func TestImports(t *testing.T) { - ImportAliasCheck(t, "../../") -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package deptest + +import "testing" + +func TestImports(t *testing.T) { + ImportAliasCheck(t, "../../") +} diff --git a/tstest/integration/gen_deps.go b/tstest/integration/gen_deps.go index ab5cc0448..23bb95ee5 100644 --- a/tstest/integration/gen_deps.go +++ b/tstest/integration/gen_deps.go @@ -1,65 +1,65 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ignore - -package main - -import ( - "bytes" - "encoding/json" - "fmt" - "log" - "os" - "os/exec" - "strings" -) - -func main() { - for _, goos := range []string{"windows", "linux", "darwin", "freebsd", "openbsd"} { - generate(goos) - } -} - -func generate(goos string) { - var x struct { - Imports []string - } - cmd := exec.Command("go", "list", "-json", "tailscale.com/cmd/tailscaled") - cmd.Env = append(os.Environ(), "GOOS="+goos, "GOARCH=amd64") - j, err := cmd.Output() - if err != nil { - log.Fatalf("GOOS=%s GOARCH=amd64 %s: %v", goos, cmd, err) - } - if err := json.Unmarshal(j, &x); err != nil { - log.Fatal(err) - } - var out bytes.Buffer - out.WriteString(`// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Code generated by gen_deps.go; DO NOT EDIT. - -package integration - -import ( - // And depend on a bunch of tailscaled innards, for Go's test caching. - // Otherwise cmd/go never sees that we depend on these packages' - // transitive deps when we run "go install tailscaled" in a child - // process and can cache a prior success when a dependency changes. -`) - for _, dep := range x.Imports { - if !strings.Contains(dep, ".") { - // Omit standard library deps. - continue - } - fmt.Fprintf(&out, "\t_ %q\n", dep) - } - fmt.Fprintf(&out, ")\n") - - filename := fmt.Sprintf("tailscaled_deps_test_%s.go", goos) - err = os.WriteFile(filename, out.Bytes(), 0644) - if err != nil { - log.Fatal(err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ignore + +package main + +import ( + "bytes" + "encoding/json" + "fmt" + "log" + "os" + "os/exec" + "strings" +) + +func main() { + for _, goos := range []string{"windows", "linux", "darwin", "freebsd", "openbsd"} { + generate(goos) + } +} + +func generate(goos string) { + var x struct { + Imports []string + } + cmd := exec.Command("go", "list", "-json", "tailscale.com/cmd/tailscaled") + cmd.Env = append(os.Environ(), "GOOS="+goos, "GOARCH=amd64") + j, err := cmd.Output() + if err != nil { + log.Fatalf("GOOS=%s GOARCH=amd64 %s: %v", goos, cmd, err) + } + if err := json.Unmarshal(j, &x); err != nil { + log.Fatal(err) + } + var out bytes.Buffer + out.WriteString(`// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen_deps.go; DO NOT EDIT. + +package integration + +import ( + // And depend on a bunch of tailscaled innards, for Go's test caching. + // Otherwise cmd/go never sees that we depend on these packages' + // transitive deps when we run "go install tailscaled" in a child + // process and can cache a prior success when a dependency changes. +`) + for _, dep := range x.Imports { + if !strings.Contains(dep, ".") { + // Omit standard library deps. + continue + } + fmt.Fprintf(&out, "\t_ %q\n", dep) + } + fmt.Fprintf(&out, ")\n") + + filename := fmt.Sprintf("tailscaled_deps_test_%s.go", goos) + err = os.WriteFile(filename, out.Bytes(), 0644) + if err != nil { + log.Fatal(err) + } +} diff --git a/tstest/integration/vms/README.md b/tstest/integration/vms/README.md index 766d8e574..519c3d000 100644 --- a/tstest/integration/vms/README.md +++ b/tstest/integration/vms/README.md @@ -1,95 +1,95 @@ -# End-to-End VM-based Integration Testing - -This test spins up a bunch of common linux distributions and then tries to get -them to connect to a -[`testcontrol`](https://pkg.go.dev/tailscale.com/tstest/integration/testcontrol) -server. - -## Running - -This test currently only runs on Linux. - -This test depends on the following command line tools: - -- [qemu](https://www.qemu.org/) -- [cdrkit](https://en.wikipedia.org/wiki/Cdrkit) -- [openssh](https://www.openssh.com/) - -This test also requires the following: - -- about 10 GB of temporary storage -- about 10 GB of cached VM images -- at least 4 GB of ram for virtual machines -- hardware virtualization support - ([KVM](https://www.linux-kvm.org/page/Main_Page)) enabled in the BIOS -- the `kvm` module to be loaded (`modprobe kvm`) -- the user running these tests must have access to `/dev/kvm` (being in the - `kvm` group should suffice) - -The `--no-s3` flag is needed to disable downloads from S3, which require -credentials. However keep in mind that some distributions do not use stable URLs -for each individual image artifact, so there may be spurious test failures as a -result. - -If you are using [Nix](https://nixos.org), you can run all of the tests with the -correct command line tools using this command: - -```console -$ nix-shell -p nixos-generators -p openssh -p go -p qemu -p cdrkit --run "go test . --run-vm-tests --v --timeout 30m --no-s3" -``` - -Keep the timeout high for the first run, especially if you are not downloading -VM images from S3. The mirrors we pull images from have download rate limits and -will take a while to download. - -Because of the hardware requirements of this test, this test will not run -without the `--run-vm-tests` flag set. - -## Other Fun Flags - -This test's behavior is customized with command line flags. - -### Don't Download Images From S3 - -If you pass the `-no-s3` flag to `go test`, the S3 step will be skipped in favor -of downloading the images directly from upstream sources, which may cause the -test to fail in odd places. - -### Distribution Picking - -This test runs on a large number of distributions. By default it tries to run -everything, which may or may not be ideal for you. If you only want to test a -subset of distributions, you can use the `--distro-regex` flag to match a subset -of distributions using a [regular expression](https://golang.org/pkg/regexp/) -such as like this: - -```console -$ go test -run-vm-tests -distro-regex centos -``` - -This would run all tests on all versions of CentOS. - -```console -$ go test -run-vm-tests -distro-regex '(debian|ubuntu)' -``` - -This would run all tests on all versions of Debian and Ubuntu. - -### Ram Limiting - -This test uses a lot of memory. In order to avoid making machines run out of -memory running this test, a semaphore is used to limit how many megabytes of ram -are being used at once. By default this semaphore is set to 4096 MB of ram -(about 4 gigabytes). You can customize this with the `--ram-limit` flag: - -```console -$ go test --run-vm-tests --ram-limit 2048 -$ go test --run-vm-tests --ram-limit 65536 -``` - -The first example will set the limit to 2048 MB of ram (about 2 gigabytes). The -second example will set the limit to 65536 MB of ram (about 65 gigabytes). -Please be careful with this flag, improper usage of it is known to cause the -Linux out-of-memory killer to engage. Try to keep it within 50-75% of your -machine's available ram (there is some overhead involved with the -virtualization) to be on the safe side. +# End-to-End VM-based Integration Testing + +This test spins up a bunch of common linux distributions and then tries to get +them to connect to a +[`testcontrol`](https://pkg.go.dev/tailscale.com/tstest/integration/testcontrol) +server. + +## Running + +This test currently only runs on Linux. + +This test depends on the following command line tools: + +- [qemu](https://www.qemu.org/) +- [cdrkit](https://en.wikipedia.org/wiki/Cdrkit) +- [openssh](https://www.openssh.com/) + +This test also requires the following: + +- about 10 GB of temporary storage +- about 10 GB of cached VM images +- at least 4 GB of ram for virtual machines +- hardware virtualization support + ([KVM](https://www.linux-kvm.org/page/Main_Page)) enabled in the BIOS +- the `kvm` module to be loaded (`modprobe kvm`) +- the user running these tests must have access to `/dev/kvm` (being in the + `kvm` group should suffice) + +The `--no-s3` flag is needed to disable downloads from S3, which require +credentials. However keep in mind that some distributions do not use stable URLs +for each individual image artifact, so there may be spurious test failures as a +result. + +If you are using [Nix](https://nixos.org), you can run all of the tests with the +correct command line tools using this command: + +```console +$ nix-shell -p nixos-generators -p openssh -p go -p qemu -p cdrkit --run "go test . --run-vm-tests --v --timeout 30m --no-s3" +``` + +Keep the timeout high for the first run, especially if you are not downloading +VM images from S3. The mirrors we pull images from have download rate limits and +will take a while to download. + +Because of the hardware requirements of this test, this test will not run +without the `--run-vm-tests` flag set. + +## Other Fun Flags + +This test's behavior is customized with command line flags. + +### Don't Download Images From S3 + +If you pass the `-no-s3` flag to `go test`, the S3 step will be skipped in favor +of downloading the images directly from upstream sources, which may cause the +test to fail in odd places. + +### Distribution Picking + +This test runs on a large number of distributions. By default it tries to run +everything, which may or may not be ideal for you. If you only want to test a +subset of distributions, you can use the `--distro-regex` flag to match a subset +of distributions using a [regular expression](https://golang.org/pkg/regexp/) +such as like this: + +```console +$ go test -run-vm-tests -distro-regex centos +``` + +This would run all tests on all versions of CentOS. + +```console +$ go test -run-vm-tests -distro-regex '(debian|ubuntu)' +``` + +This would run all tests on all versions of Debian and Ubuntu. + +### Ram Limiting + +This test uses a lot of memory. In order to avoid making machines run out of +memory running this test, a semaphore is used to limit how many megabytes of ram +are being used at once. By default this semaphore is set to 4096 MB of ram +(about 4 gigabytes). You can customize this with the `--ram-limit` flag: + +```console +$ go test --run-vm-tests --ram-limit 2048 +$ go test --run-vm-tests --ram-limit 65536 +``` + +The first example will set the limit to 2048 MB of ram (about 2 gigabytes). The +second example will set the limit to 65536 MB of ram (about 65 gigabytes). +Please be careful with this flag, improper usage of it is known to cause the +Linux out-of-memory killer to engage. Try to keep it within 50-75% of your +machine's available ram (there is some overhead involved with the +virtualization) to be on the safe side. diff --git a/tstest/integration/vms/distros.hujson b/tstest/integration/vms/distros.hujson index 5634d6d67..049091ed5 100644 --- a/tstest/integration/vms/distros.hujson +++ b/tstest/integration/vms/distros.hujson @@ -1,39 +1,39 @@ -// NOTE(Xe): If you run into issues getting the autoconfig to work, run -// this test with the flag `--distro-regex=alpine-edge`. Connect with a VNC -// client with a command like this: -// -// $ vncviewer :0 -// -// On NixOS you can get away with something like this: -// -// $ env NIXPKGS_ALLOW_UNFREE=1 nix-shell -p tigervnc --run 'vncviewer :0' -// -// Login as root with the password root. Then look in -// /var/log/cloud-init-output.log for what you messed up. -[ - { - "Name": "ubuntu-18-04", - "URL": "https://cloud-images.ubuntu.com/releases/bionic/release-20210817/ubuntu-18.04-server-cloudimg-amd64.img", - "SHA256Sum": "1ee1039f0b91c8367351413b5b5f56026aaf302fd5f66f17f8215132d6e946d2", - "MemoryMegs": 512, - "PackageManager": "apt", - "InitSystem": "systemd" - }, - { - "Name": "ubuntu-20-04", - "URL": "https://cloud-images.ubuntu.com/releases/focal/release-20210819/ubuntu-20.04-server-cloudimg-amd64.img", - "SHA256Sum": "99e25e6e344e3a50a081235e825937238a3d51b099969e107ef66f0d3a1f955e", - "MemoryMegs": 512, - "PackageManager": "apt", - "InitSystem": "systemd" - }, - { - "Name": "nixos-21-11", - "URL": "channel:nixos-21.11", - "SHA256Sum": "lolfakesha", - "MemoryMegs": 512, - "PackageManager": "nix", - "InitSystem": "systemd", - "HostGenerated": true - }, -] +// NOTE(Xe): If you run into issues getting the autoconfig to work, run +// this test with the flag `--distro-regex=alpine-edge`. Connect with a VNC +// client with a command like this: +// +// $ vncviewer :0 +// +// On NixOS you can get away with something like this: +// +// $ env NIXPKGS_ALLOW_UNFREE=1 nix-shell -p tigervnc --run 'vncviewer :0' +// +// Login as root with the password root. Then look in +// /var/log/cloud-init-output.log for what you messed up. +[ + { + "Name": "ubuntu-18-04", + "URL": "https://cloud-images.ubuntu.com/releases/bionic/release-20210817/ubuntu-18.04-server-cloudimg-amd64.img", + "SHA256Sum": "1ee1039f0b91c8367351413b5b5f56026aaf302fd5f66f17f8215132d6e946d2", + "MemoryMegs": 512, + "PackageManager": "apt", + "InitSystem": "systemd" + }, + { + "Name": "ubuntu-20-04", + "URL": "https://cloud-images.ubuntu.com/releases/focal/release-20210819/ubuntu-20.04-server-cloudimg-amd64.img", + "SHA256Sum": "99e25e6e344e3a50a081235e825937238a3d51b099969e107ef66f0d3a1f955e", + "MemoryMegs": 512, + "PackageManager": "apt", + "InitSystem": "systemd" + }, + { + "Name": "nixos-21-11", + "URL": "channel:nixos-21.11", + "SHA256Sum": "lolfakesha", + "MemoryMegs": 512, + "PackageManager": "nix", + "InitSystem": "systemd", + "HostGenerated": true + }, +] diff --git a/tstest/integration/vms/distros_test.go b/tstest/integration/vms/distros_test.go index db3bae793..462aa2a6b 100644 --- a/tstest/integration/vms/distros_test.go +++ b/tstest/integration/vms/distros_test.go @@ -1,14 +1,14 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package vms - -import ( - "testing" -) - -func TestDistrosGotLoaded(t *testing.T) { - if len(Distros) == 0 { - t.Fatal("no distros were loaded") - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package vms + +import ( + "testing" +) + +func TestDistrosGotLoaded(t *testing.T) { + if len(Distros) == 0 { + t.Fatal("no distros were loaded") + } +} diff --git a/tstest/integration/vms/dns_tester.go b/tstest/integration/vms/dns_tester.go index be7d7ee6d..50b39bb5f 100644 --- a/tstest/integration/vms/dns_tester.go +++ b/tstest/integration/vms/dns_tester.go @@ -1,54 +1,54 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ignore - -// Command dns_tester exists in order to perform tests of our DNS -// configuration stack. This was written because the state of DNS -// in our target environments is so diverse that we need a little tool -// to do this test for us. -package main - -import ( - "context" - "encoding/json" - "flag" - "net" - "os" - "time" -) - -func main() { - flag.Parse() - target := flag.Arg(0) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - errCount := 0 - wait := 25 * time.Millisecond - for range make([]struct{}, 5) { - err := lookup(ctx, target) - if err != nil { - errCount++ - time.Sleep(wait) - wait = wait * 2 - continue - } - - break - } -} - -func lookup(ctx context.Context, target string) error { - ctx, cancel := context.WithTimeout(ctx, time.Second) - defer cancel() - - hosts, err := net.LookupHost(target) - if err != nil { - return err - } - - json.NewEncoder(os.Stdout).Encode(hosts) - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ignore + +// Command dns_tester exists in order to perform tests of our DNS +// configuration stack. This was written because the state of DNS +// in our target environments is so diverse that we need a little tool +// to do this test for us. +package main + +import ( + "context" + "encoding/json" + "flag" + "net" + "os" + "time" +) + +func main() { + flag.Parse() + target := flag.Arg(0) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + errCount := 0 + wait := 25 * time.Millisecond + for range make([]struct{}, 5) { + err := lookup(ctx, target) + if err != nil { + errCount++ + time.Sleep(wait) + wait = wait * 2 + continue + } + + break + } +} + +func lookup(ctx context.Context, target string) error { + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + + hosts, err := net.LookupHost(target) + if err != nil { + return err + } + + json.NewEncoder(os.Stdout).Encode(hosts) + return nil +} diff --git a/tstest/integration/vms/doc.go b/tstest/integration/vms/doc.go index 3008493ea..6093b53ac 100644 --- a/tstest/integration/vms/doc.go +++ b/tstest/integration/vms/doc.go @@ -1,6 +1,6 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package vms does VM-based integration/functional tests by using -// qemu and a bank of pre-made VM images. -package vms +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package vms does VM-based integration/functional tests by using +// qemu and a bank of pre-made VM images. +package vms diff --git a/tstest/integration/vms/harness_test.go b/tstest/integration/vms/harness_test.go index 620276ac2..1e080414d 100644 --- a/tstest/integration/vms/harness_test.go +++ b/tstest/integration/vms/harness_test.go @@ -1,242 +1,242 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows && !plan9 - -package vms - -import ( - "bytes" - "context" - "fmt" - "log" - "net" - "net/http" - "net/netip" - "os" - "os/exec" - "path" - "path/filepath" - "strconv" - "sync" - "testing" - "time" - - "golang.org/x/crypto/ssh" - "golang.org/x/net/proxy" - "tailscale.com/tailcfg" - "tailscale.com/tstest/integration" - "tailscale.com/tstest/integration/testcontrol" - "tailscale.com/types/dnstype" -) - -type Harness struct { - testerDialer proxy.Dialer - testerDir string - binaryDir string - cli string - daemon string - pubKey string - signer ssh.Signer - cs *testcontrol.Server - loginServerURL string - testerV4 netip.Addr - ipMu *sync.Mutex - ipMap map[string]ipMapping -} - -func newHarness(t *testing.T) *Harness { - dir := t.TempDir() - bindHost := deriveBindhost(t) - ln, err := net.Listen("tcp", net.JoinHostPort(bindHost, "0")) - if err != nil { - t.Fatalf("can't make TCP listener: %v", err) - } - t.Cleanup(func() { - ln.Close() - }) - t.Logf("host:port: %s", ln.Addr()) - - cs := &testcontrol.Server{ - DNSConfig: &tailcfg.DNSConfig{ - // TODO: this is wrong. - // It is also only one of many configurations. - // Figure out how to scale it up. - Resolvers: []*dnstype.Resolver{{Addr: "100.100.100.100"}, {Addr: "8.8.8.8"}}, - Domains: []string{"record"}, - Proxied: true, - ExtraRecords: []tailcfg.DNSRecord{{Name: "extratest.record", Type: "A", Value: "1.2.3.4"}}, - }, - } - - derpMap := integration.RunDERPAndSTUN(t, t.Logf, bindHost) - cs.DERPMap = derpMap - - var ( - ipMu sync.Mutex - ipMap = map[string]ipMapping{} - ) - - mux := http.NewServeMux() - mux.Handle("/", cs) - - lc := &integration.LogCatcher{} - if *verboseLogcatcher { - lc.UseLogf(t.Logf) - t.Cleanup(func() { - lc.UseLogf(nil) // do not log after test is complete - }) - } - mux.Handle("/c/", lc) - - // This handler will let the virtual machines tell the host information about that VM. - // This is used to maintain a list of port->IP address mappings that are known to be - // working. This allows later steps to connect over SSH. This returns no response to - // clients because no response is needed. - mux.HandleFunc("/myip/", func(w http.ResponseWriter, r *http.Request) { - ipMu.Lock() - defer ipMu.Unlock() - - name := path.Base(r.URL.Path) - host, _, _ := net.SplitHostPort(r.RemoteAddr) - port, err := strconv.Atoi(name) - if err != nil { - log.Panicf("bad port: %v", port) - } - distro := r.UserAgent() - ipMap[distro] = ipMapping{distro, port, host} - t.Logf("%s: %v", name, host) - }) - - hs := &http.Server{Handler: mux} - go hs.Serve(ln) - - cmd := exec.Command("ssh-keygen", "-t", "ed25519", "-f", "machinekey", "-N", "") - cmd.Dir = dir - if out, err := cmd.CombinedOutput(); err != nil { - t.Fatalf("ssh-keygen: %v, %s", err, out) - } - pubkey, err := os.ReadFile(filepath.Join(dir, "machinekey.pub")) - if err != nil { - t.Fatalf("can't read ssh key: %v", err) - } - - privateKey, err := os.ReadFile(filepath.Join(dir, "machinekey")) - if err != nil { - t.Fatalf("can't read ssh private key: %v", err) - } - - signer, err := ssh.ParsePrivateKey(privateKey) - if err != nil { - t.Fatalf("can't parse private key: %v", err) - } - - loginServer := fmt.Sprintf("http://%s", ln.Addr()) - t.Logf("loginServer: %s", loginServer) - - h := &Harness{ - pubKey: string(pubkey), - binaryDir: integration.BinaryDir(t), - cli: integration.TailscaleBinary(t), - daemon: integration.TailscaledBinary(t), - signer: signer, - loginServerURL: loginServer, - cs: cs, - ipMu: &ipMu, - ipMap: ipMap, - } - - h.makeTestNode(t, loginServer) - - return h -} - -func (h *Harness) Tailscale(t *testing.T, args ...string) []byte { - t.Helper() - - args = append([]string{"--socket=" + filepath.Join(h.testerDir, "sock")}, args...) - - cmd := exec.Command(h.cli, args...) - out, err := cmd.CombinedOutput() - if err != nil { - t.Fatal(err) - } - - return out -} - -// makeTestNode creates a userspace tailscaled running in netstack mode that -// enables us to make connections to and from the tailscale network being -// tested. This mutates the Harness to allow tests to dial into the tailscale -// network as well as control the tester's tailscaled. -func (h *Harness) makeTestNode(t *testing.T, controlURL string) { - dir := t.TempDir() - h.testerDir = dir - - port, err := getProbablyFreePortNumber() - if err != nil { - t.Fatalf("can't get free port: %v", err) - } - - cmd := exec.Command( - h.daemon, - "--tun=userspace-networking", - "--state="+filepath.Join(dir, "state.json"), - "--socket="+filepath.Join(dir, "sock"), - fmt.Sprintf("--socks5-server=localhost:%d", port), - ) - - cmd.Env = append( - os.Environ(), - "NOTIFY_SOCKET="+filepath.Join(dir, "notify_socket"), - "TS_LOG_TARGET="+h.loginServerURL, - ) - - err = cmd.Start() - if err != nil { - t.Fatalf("can't start tailscaled: %v", err) - } - - t.Cleanup(func() { - cmd.Process.Kill() - }) - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - ticker := time.NewTicker(100 * time.Millisecond) - -outer: - for { - select { - case <-ctx.Done(): - t.Fatal("timed out waiting for tailscaled to come up") - return - case <-ticker.C: - conn, err := net.Dial("unix", filepath.Join(dir, "sock")) - if err != nil { - continue - } - - conn.Close() - break outer - } - } - - run(t, dir, h.cli, - "--socket="+filepath.Join(dir, "sock"), - "up", - "--login-server="+controlURL, - "--hostname=tester", - ) - - dialer, err := proxy.SOCKS5("tcp", net.JoinHostPort("127.0.0.1", fmt.Sprint(port)), nil, &net.Dialer{}) - if err != nil { - t.Fatalf("can't make netstack proxy dialer: %v", err) - } - h.testerDialer = dialer - h.testerV4 = bytes2Netaddr(h.Tailscale(t, "ip", "-4")) -} - -func bytes2Netaddr(inp []byte) netip.Addr { - return netip.MustParseAddr(string(bytes.TrimSpace(inp))) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows && !plan9 + +package vms + +import ( + "bytes" + "context" + "fmt" + "log" + "net" + "net/http" + "net/netip" + "os" + "os/exec" + "path" + "path/filepath" + "strconv" + "sync" + "testing" + "time" + + "golang.org/x/crypto/ssh" + "golang.org/x/net/proxy" + "tailscale.com/tailcfg" + "tailscale.com/tstest/integration" + "tailscale.com/tstest/integration/testcontrol" + "tailscale.com/types/dnstype" +) + +type Harness struct { + testerDialer proxy.Dialer + testerDir string + binaryDir string + cli string + daemon string + pubKey string + signer ssh.Signer + cs *testcontrol.Server + loginServerURL string + testerV4 netip.Addr + ipMu *sync.Mutex + ipMap map[string]ipMapping +} + +func newHarness(t *testing.T) *Harness { + dir := t.TempDir() + bindHost := deriveBindhost(t) + ln, err := net.Listen("tcp", net.JoinHostPort(bindHost, "0")) + if err != nil { + t.Fatalf("can't make TCP listener: %v", err) + } + t.Cleanup(func() { + ln.Close() + }) + t.Logf("host:port: %s", ln.Addr()) + + cs := &testcontrol.Server{ + DNSConfig: &tailcfg.DNSConfig{ + // TODO: this is wrong. + // It is also only one of many configurations. + // Figure out how to scale it up. + Resolvers: []*dnstype.Resolver{{Addr: "100.100.100.100"}, {Addr: "8.8.8.8"}}, + Domains: []string{"record"}, + Proxied: true, + ExtraRecords: []tailcfg.DNSRecord{{Name: "extratest.record", Type: "A", Value: "1.2.3.4"}}, + }, + } + + derpMap := integration.RunDERPAndSTUN(t, t.Logf, bindHost) + cs.DERPMap = derpMap + + var ( + ipMu sync.Mutex + ipMap = map[string]ipMapping{} + ) + + mux := http.NewServeMux() + mux.Handle("/", cs) + + lc := &integration.LogCatcher{} + if *verboseLogcatcher { + lc.UseLogf(t.Logf) + t.Cleanup(func() { + lc.UseLogf(nil) // do not log after test is complete + }) + } + mux.Handle("/c/", lc) + + // This handler will let the virtual machines tell the host information about that VM. + // This is used to maintain a list of port->IP address mappings that are known to be + // working. This allows later steps to connect over SSH. This returns no response to + // clients because no response is needed. + mux.HandleFunc("/myip/", func(w http.ResponseWriter, r *http.Request) { + ipMu.Lock() + defer ipMu.Unlock() + + name := path.Base(r.URL.Path) + host, _, _ := net.SplitHostPort(r.RemoteAddr) + port, err := strconv.Atoi(name) + if err != nil { + log.Panicf("bad port: %v", port) + } + distro := r.UserAgent() + ipMap[distro] = ipMapping{distro, port, host} + t.Logf("%s: %v", name, host) + }) + + hs := &http.Server{Handler: mux} + go hs.Serve(ln) + + cmd := exec.Command("ssh-keygen", "-t", "ed25519", "-f", "machinekey", "-N", "") + cmd.Dir = dir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("ssh-keygen: %v, %s", err, out) + } + pubkey, err := os.ReadFile(filepath.Join(dir, "machinekey.pub")) + if err != nil { + t.Fatalf("can't read ssh key: %v", err) + } + + privateKey, err := os.ReadFile(filepath.Join(dir, "machinekey")) + if err != nil { + t.Fatalf("can't read ssh private key: %v", err) + } + + signer, err := ssh.ParsePrivateKey(privateKey) + if err != nil { + t.Fatalf("can't parse private key: %v", err) + } + + loginServer := fmt.Sprintf("http://%s", ln.Addr()) + t.Logf("loginServer: %s", loginServer) + + h := &Harness{ + pubKey: string(pubkey), + binaryDir: integration.BinaryDir(t), + cli: integration.TailscaleBinary(t), + daemon: integration.TailscaledBinary(t), + signer: signer, + loginServerURL: loginServer, + cs: cs, + ipMu: &ipMu, + ipMap: ipMap, + } + + h.makeTestNode(t, loginServer) + + return h +} + +func (h *Harness) Tailscale(t *testing.T, args ...string) []byte { + t.Helper() + + args = append([]string{"--socket=" + filepath.Join(h.testerDir, "sock")}, args...) + + cmd := exec.Command(h.cli, args...) + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatal(err) + } + + return out +} + +// makeTestNode creates a userspace tailscaled running in netstack mode that +// enables us to make connections to and from the tailscale network being +// tested. This mutates the Harness to allow tests to dial into the tailscale +// network as well as control the tester's tailscaled. +func (h *Harness) makeTestNode(t *testing.T, controlURL string) { + dir := t.TempDir() + h.testerDir = dir + + port, err := getProbablyFreePortNumber() + if err != nil { + t.Fatalf("can't get free port: %v", err) + } + + cmd := exec.Command( + h.daemon, + "--tun=userspace-networking", + "--state="+filepath.Join(dir, "state.json"), + "--socket="+filepath.Join(dir, "sock"), + fmt.Sprintf("--socks5-server=localhost:%d", port), + ) + + cmd.Env = append( + os.Environ(), + "NOTIFY_SOCKET="+filepath.Join(dir, "notify_socket"), + "TS_LOG_TARGET="+h.loginServerURL, + ) + + err = cmd.Start() + if err != nil { + t.Fatalf("can't start tailscaled: %v", err) + } + + t.Cleanup(func() { + cmd.Process.Kill() + }) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + ticker := time.NewTicker(100 * time.Millisecond) + +outer: + for { + select { + case <-ctx.Done(): + t.Fatal("timed out waiting for tailscaled to come up") + return + case <-ticker.C: + conn, err := net.Dial("unix", filepath.Join(dir, "sock")) + if err != nil { + continue + } + + conn.Close() + break outer + } + } + + run(t, dir, h.cli, + "--socket="+filepath.Join(dir, "sock"), + "up", + "--login-server="+controlURL, + "--hostname=tester", + ) + + dialer, err := proxy.SOCKS5("tcp", net.JoinHostPort("127.0.0.1", fmt.Sprint(port)), nil, &net.Dialer{}) + if err != nil { + t.Fatalf("can't make netstack proxy dialer: %v", err) + } + h.testerDialer = dialer + h.testerV4 = bytes2Netaddr(h.Tailscale(t, "ip", "-4")) +} + +func bytes2Netaddr(inp []byte) netip.Addr { + return netip.MustParseAddr(string(bytes.TrimSpace(inp))) +} diff --git a/tstest/integration/vms/nixos_test.go b/tstest/integration/vms/nixos_test.go index 06a14e4f6..c2998ff3c 100644 --- a/tstest/integration/vms/nixos_test.go +++ b/tstest/integration/vms/nixos_test.go @@ -1,231 +1,231 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows && !plan9 - -package vms - -import ( - "flag" - "fmt" - "os" - "os/exec" - "path/filepath" - "strings" - "testing" - "text/template" - - "tailscale.com/types/logger" -) - -var ( - verboseNixOutput = flag.Bool("verbose-nix-output", false, "if set, use verbose nix output (lots of noise)") -) - -/* - NOTE(Xe): Okay, so, at a high level testing NixOS is a lot different than - other distros due to NixOS' determinism. Normally NixOS wants packages to - be defined in either an overlay, a custom packageOverrides or even - yolo-inline as a part of the system configuration. This is going to have - us take a different approach compared to other distributions. The overall - plan here is as following: - - 1. make the binaries as normal - 2. template in their paths as raw strings to the nixos system module - 3. run `nixos-generators -f qcow -o $CACHE_DIR/tailscale/nixos/version -c generated-config.nix` - 4. pass that to the steps that make the virtual machine - - It doesn't really make sense for us to use a premade virtual machine image - for this as that will make it harder to deterministically create the image. -*/ - -const nixosConfigTemplate = ` -# NOTE(Xe): This template is going to be heavily commented. - -# All NixOS modules are functions. Here is the function prelude for this NixOS -# module that defines the system. It is a function that takes in an attribute -# set (effectively a map[string]nix.Value) and destructures it to some variables: -{ - # other NixOS settings as defined in other modules - config, - - # nixpkgs, which is basically the standard library of NixOS - pkgs, - - # the path to some system-scoped NixOS modules that aren't imported by default - modulesPath, - - # the rest of the arguments don't matter - ... -}: - -# Nix's syntax was inspired by Haskell and other functional languages, so the -# let .. in pattern is used to create scoped variables: -let - # Define the package (derivation) for Tailscale based on the binaries we - # just built for this test: - testTailscale = pkgs.stdenv.mkDerivation { - # The name of the package. This usually includes a version however it - # doesn't matter here. - name = "tailscale-test"; - - # The path on disk to the "source code" of the package, in this case it is - # the path to the binaries that are built. This needs to be the raw - # unquoted slash-separated path, not a string containing the path because Nix - # has a special path type. - src = {{.BinPath}}; - - # We only need to worry about the install phase because we've already - # built the binaries. - phases = "installPhase"; - - # We need to wrap tailscaled such that it has iptables in its $PATH. - nativeBuildInputs = [ pkgs.makeWrapper ]; - - # The install instructions for this package ('' ''defines a multi-line string). - # The with statement lets us bring in values into scope as if they were - # defined in the current scope. - installPhase = with pkgs; '' - # This is bash. - - # Make the output folders for the package (systemd unit and binary folders). - mkdir -p $out/bin - - # Install tailscale{,d} - cp $src/tailscale $out/bin/tailscale - cp $src/tailscaled $out/bin/tailscaled - - # Wrap tailscaled with the ip and iptables commands. - wrapProgram $out/bin/tailscaled --prefix PATH : ${ - lib.makeBinPath [ iproute iptables ] - } - - # Install systemd unit. - cp $src/systemd/tailscaled.service . - sed -i -e "s#/usr/sbin#$out/bin#" -e "/^EnvironmentFile/d" ./tailscaled.service - install -D -m0444 -t $out/lib/systemd/system ./tailscaled.service - ''; - }; -in { - # This is a QEMU VM. This module has a lot of common qemu VM settings so you - # don't have to set them manually. - imports = [ (modulesPath + "/profiles/qemu-guest.nix") ]; - - # We need virtio support to boot. - boot.initrd.availableKernelModules = - [ "ata_piix" "uhci_hcd" "virtio_pci" "sr_mod" "virtio_blk" ]; - boot.initrd.kernelModules = [ ]; - boot.kernelModules = [ ]; - boot.extraModulePackages = [ ]; - - # Curl is needed for one of the steps in cloud-final - systemd.services.cloud-final.path = with pkgs; [ curl ]; - - # Curl is needed for one of the integration tests - environment.systemPackages = with pkgs; [ curl nix bash squid openssl daemonize ]; - - # yolo, this vm can sudo freely. - security.sudo.wheelNeedsPassword = false; - - # Enable cloud-init so we can set VM hostnames and the like the same as other - # distros. This will also take care of SSH keys. It's pretty handy. - services.cloud-init = { - enable = true; - ext4.enable = true; - }; - - # We want sshd running. - services.openssh.enable = true; - - # Tailscale settings: - services.tailscale = { - # We want Tailscale to start at boot. - enable = true; - - # Use the Tailscale package we just assembled. - package = testTailscale; - }; - - # Override TS_LOG_TARGET to our private logcatcher. - systemd.services.tailscaled.environment."TS_LOG_TARGET" = "{{.LogTarget}}"; -}` - -func (h *Harness) copyUnit(t *testing.T) { - t.Helper() - - data, err := os.ReadFile("../../../cmd/tailscaled/tailscaled.service") - if err != nil { - t.Fatal(err) - } - os.MkdirAll(filepath.Join(h.binaryDir, "systemd"), 0755) - err = os.WriteFile(filepath.Join(h.binaryDir, "systemd", "tailscaled.service"), data, 0666) - if err != nil { - t.Fatal(err) - } -} - -func (h *Harness) makeNixOSImage(t *testing.T, d Distro, cdir string) string { - if d.Name == "nixos-unstable" { - t.Skip("https://github.com/NixOS/nixpkgs/issues/131098") - } - - h.copyUnit(t) - dir := t.TempDir() - fname := filepath.Join(dir, d.Name+".nix") - fout, err := os.Create(fname) - if err != nil { - t.Fatal(err) - } - - tmpl := template.Must(template.New("base.nix").Parse(nixosConfigTemplate)) - err = tmpl.Execute(fout, struct { - BinPath string - LogTarget string - }{ - BinPath: h.binaryDir, - LogTarget: h.loginServerURL, - }) - if err != nil { - t.Fatal(err) - } - - err = fout.Close() - if err != nil { - t.Fatal(err) - } - - outpath := filepath.Join(cdir, "nixos") - os.MkdirAll(outpath, 0755) - - t.Cleanup(func() { - os.RemoveAll(filepath.Join(outpath, d.Name)) // makes the disk image a candidate for GC - }) - - cmd := exec.Command("nixos-generate", "-f", "qcow", "-o", filepath.Join(outpath, d.Name), "-c", fname) - if *verboseNixOutput { - cmd.Stdout = logger.FuncWriter(t.Logf) - cmd.Stderr = logger.FuncWriter(t.Logf) - } else { - fname := fmt.Sprintf("nix-build-%s-%s", os.Getenv("GITHUB_RUN_NUMBER"), strings.Replace(t.Name(), "/", "-", -1)) - t.Logf("writing nix logs to %s", fname) - fout, err := os.Create(fname) - if err != nil { - t.Fatalf("can't make log file for nix build: %v", err) - } - cmd.Stdout = fout - cmd.Stderr = fout - defer fout.Close() - } - cmd.Env = append(os.Environ(), "NIX_PATH=nixpkgs="+d.URL) - cmd.Dir = outpath - t.Logf("running %s %#v", "nixos-generate", cmd.Args) - if err := cmd.Run(); err != nil { - t.Fatalf("error while making NixOS image for %s: %v", d.Name, err) - } - - if !*verboseNixOutput { - t.Log("done") - } - - return filepath.Join(outpath, d.Name, "nixos.qcow2") -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows && !plan9 + +package vms + +import ( + "flag" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + "text/template" + + "tailscale.com/types/logger" +) + +var ( + verboseNixOutput = flag.Bool("verbose-nix-output", false, "if set, use verbose nix output (lots of noise)") +) + +/* + NOTE(Xe): Okay, so, at a high level testing NixOS is a lot different than + other distros due to NixOS' determinism. Normally NixOS wants packages to + be defined in either an overlay, a custom packageOverrides or even + yolo-inline as a part of the system configuration. This is going to have + us take a different approach compared to other distributions. The overall + plan here is as following: + + 1. make the binaries as normal + 2. template in their paths as raw strings to the nixos system module + 3. run `nixos-generators -f qcow -o $CACHE_DIR/tailscale/nixos/version -c generated-config.nix` + 4. pass that to the steps that make the virtual machine + + It doesn't really make sense for us to use a premade virtual machine image + for this as that will make it harder to deterministically create the image. +*/ + +const nixosConfigTemplate = ` +# NOTE(Xe): This template is going to be heavily commented. + +# All NixOS modules are functions. Here is the function prelude for this NixOS +# module that defines the system. It is a function that takes in an attribute +# set (effectively a map[string]nix.Value) and destructures it to some variables: +{ + # other NixOS settings as defined in other modules + config, + + # nixpkgs, which is basically the standard library of NixOS + pkgs, + + # the path to some system-scoped NixOS modules that aren't imported by default + modulesPath, + + # the rest of the arguments don't matter + ... +}: + +# Nix's syntax was inspired by Haskell and other functional languages, so the +# let .. in pattern is used to create scoped variables: +let + # Define the package (derivation) for Tailscale based on the binaries we + # just built for this test: + testTailscale = pkgs.stdenv.mkDerivation { + # The name of the package. This usually includes a version however it + # doesn't matter here. + name = "tailscale-test"; + + # The path on disk to the "source code" of the package, in this case it is + # the path to the binaries that are built. This needs to be the raw + # unquoted slash-separated path, not a string containing the path because Nix + # has a special path type. + src = {{.BinPath}}; + + # We only need to worry about the install phase because we've already + # built the binaries. + phases = "installPhase"; + + # We need to wrap tailscaled such that it has iptables in its $PATH. + nativeBuildInputs = [ pkgs.makeWrapper ]; + + # The install instructions for this package ('' ''defines a multi-line string). + # The with statement lets us bring in values into scope as if they were + # defined in the current scope. + installPhase = with pkgs; '' + # This is bash. + + # Make the output folders for the package (systemd unit and binary folders). + mkdir -p $out/bin + + # Install tailscale{,d} + cp $src/tailscale $out/bin/tailscale + cp $src/tailscaled $out/bin/tailscaled + + # Wrap tailscaled with the ip and iptables commands. + wrapProgram $out/bin/tailscaled --prefix PATH : ${ + lib.makeBinPath [ iproute iptables ] + } + + # Install systemd unit. + cp $src/systemd/tailscaled.service . + sed -i -e "s#/usr/sbin#$out/bin#" -e "/^EnvironmentFile/d" ./tailscaled.service + install -D -m0444 -t $out/lib/systemd/system ./tailscaled.service + ''; + }; +in { + # This is a QEMU VM. This module has a lot of common qemu VM settings so you + # don't have to set them manually. + imports = [ (modulesPath + "/profiles/qemu-guest.nix") ]; + + # We need virtio support to boot. + boot.initrd.availableKernelModules = + [ "ata_piix" "uhci_hcd" "virtio_pci" "sr_mod" "virtio_blk" ]; + boot.initrd.kernelModules = [ ]; + boot.kernelModules = [ ]; + boot.extraModulePackages = [ ]; + + # Curl is needed for one of the steps in cloud-final + systemd.services.cloud-final.path = with pkgs; [ curl ]; + + # Curl is needed for one of the integration tests + environment.systemPackages = with pkgs; [ curl nix bash squid openssl daemonize ]; + + # yolo, this vm can sudo freely. + security.sudo.wheelNeedsPassword = false; + + # Enable cloud-init so we can set VM hostnames and the like the same as other + # distros. This will also take care of SSH keys. It's pretty handy. + services.cloud-init = { + enable = true; + ext4.enable = true; + }; + + # We want sshd running. + services.openssh.enable = true; + + # Tailscale settings: + services.tailscale = { + # We want Tailscale to start at boot. + enable = true; + + # Use the Tailscale package we just assembled. + package = testTailscale; + }; + + # Override TS_LOG_TARGET to our private logcatcher. + systemd.services.tailscaled.environment."TS_LOG_TARGET" = "{{.LogTarget}}"; +}` + +func (h *Harness) copyUnit(t *testing.T) { + t.Helper() + + data, err := os.ReadFile("../../../cmd/tailscaled/tailscaled.service") + if err != nil { + t.Fatal(err) + } + os.MkdirAll(filepath.Join(h.binaryDir, "systemd"), 0755) + err = os.WriteFile(filepath.Join(h.binaryDir, "systemd", "tailscaled.service"), data, 0666) + if err != nil { + t.Fatal(err) + } +} + +func (h *Harness) makeNixOSImage(t *testing.T, d Distro, cdir string) string { + if d.Name == "nixos-unstable" { + t.Skip("https://github.com/NixOS/nixpkgs/issues/131098") + } + + h.copyUnit(t) + dir := t.TempDir() + fname := filepath.Join(dir, d.Name+".nix") + fout, err := os.Create(fname) + if err != nil { + t.Fatal(err) + } + + tmpl := template.Must(template.New("base.nix").Parse(nixosConfigTemplate)) + err = tmpl.Execute(fout, struct { + BinPath string + LogTarget string + }{ + BinPath: h.binaryDir, + LogTarget: h.loginServerURL, + }) + if err != nil { + t.Fatal(err) + } + + err = fout.Close() + if err != nil { + t.Fatal(err) + } + + outpath := filepath.Join(cdir, "nixos") + os.MkdirAll(outpath, 0755) + + t.Cleanup(func() { + os.RemoveAll(filepath.Join(outpath, d.Name)) // makes the disk image a candidate for GC + }) + + cmd := exec.Command("nixos-generate", "-f", "qcow", "-o", filepath.Join(outpath, d.Name), "-c", fname) + if *verboseNixOutput { + cmd.Stdout = logger.FuncWriter(t.Logf) + cmd.Stderr = logger.FuncWriter(t.Logf) + } else { + fname := fmt.Sprintf("nix-build-%s-%s", os.Getenv("GITHUB_RUN_NUMBER"), strings.Replace(t.Name(), "/", "-", -1)) + t.Logf("writing nix logs to %s", fname) + fout, err := os.Create(fname) + if err != nil { + t.Fatalf("can't make log file for nix build: %v", err) + } + cmd.Stdout = fout + cmd.Stderr = fout + defer fout.Close() + } + cmd.Env = append(os.Environ(), "NIX_PATH=nixpkgs="+d.URL) + cmd.Dir = outpath + t.Logf("running %s %#v", "nixos-generate", cmd.Args) + if err := cmd.Run(); err != nil { + t.Fatalf("error while making NixOS image for %s: %v", d.Name, err) + } + + if !*verboseNixOutput { + t.Log("done") + } + + return filepath.Join(outpath, d.Name, "nixos.qcow2") +} diff --git a/tstest/integration/vms/regex_flag.go b/tstest/integration/vms/regex_flag.go index 195f7c771..02e399ecd 100644 --- a/tstest/integration/vms/regex_flag.go +++ b/tstest/integration/vms/regex_flag.go @@ -1,29 +1,29 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package vms - -import "regexp" - -type regexValue struct { - r *regexp.Regexp -} - -func (r *regexValue) String() string { - if r.r == nil { - return "" - } - - return r.r.String() -} - -func (r *regexValue) Set(val string) error { - if rex, err := regexp.Compile(val); err != nil { - return err - } else { - r.r = rex - return nil - } -} - -func (r regexValue) Unwrap() *regexp.Regexp { return r.r } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package vms + +import "regexp" + +type regexValue struct { + r *regexp.Regexp +} + +func (r *regexValue) String() string { + if r.r == nil { + return "" + } + + return r.r.String() +} + +func (r *regexValue) Set(val string) error { + if rex, err := regexp.Compile(val); err != nil { + return err + } else { + r.r = rex + return nil + } +} + +func (r regexValue) Unwrap() *regexp.Regexp { return r.r } diff --git a/tstest/integration/vms/regex_flag_test.go b/tstest/integration/vms/regex_flag_test.go index 790894080..0f4e5f8f7 100644 --- a/tstest/integration/vms/regex_flag_test.go +++ b/tstest/integration/vms/regex_flag_test.go @@ -1,21 +1,21 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package vms - -import ( - "flag" - "testing" -) - -func TestRegexFlag(t *testing.T) { - var v regexValue - fs := flag.NewFlagSet(t.Name(), flag.PanicOnError) - fs.Var(&v, "regex", "regex to parse") - - const want = `.*` - fs.Parse([]string{"-regex", want}) - if v.Unwrap().String() != want { - t.Fatalf("got wrong regex: %q, wanted: %q", v.Unwrap().String(), want) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package vms + +import ( + "flag" + "testing" +) + +func TestRegexFlag(t *testing.T) { + var v regexValue + fs := flag.NewFlagSet(t.Name(), flag.PanicOnError) + fs.Var(&v, "regex", "regex to parse") + + const want = `.*` + fs.Parse([]string{"-regex", want}) + if v.Unwrap().String() != want { + t.Fatalf("got wrong regex: %q, wanted: %q", v.Unwrap().String(), want) + } +} diff --git a/tstest/integration/vms/runner.nix b/tstest/integration/vms/runner.nix index 8d4c0a25d..ac569cf65 100644 --- a/tstest/integration/vms/runner.nix +++ b/tstest/integration/vms/runner.nix @@ -1,89 +1,89 @@ -# This is a NixOS module to allow a machine to act as an integration test -# runner. This is used for the end-to-end VM test suite. - -{ lib, config, pkgs, ... }: - -{ - # The GitHub Actions self-hosted runner service. - services.github-runner = { - enable = true; - url = "https://github.com/tailscale/tailscale"; - replace = true; - extraLabels = [ "vm_integration_test" ]; - - # Justifications for the packages: - extraPackages = with pkgs; [ - # The test suite is written in Go. - go - - # This contains genisoimage, which is needed to create cloud-init - # seeds. - cdrkit - - # This package is the virtual machine hypervisor we use in tests. - qemu - - # This package contains tools like `ssh-keygen`. - openssh - - # The C compiler so cgo builds work. - gcc - - # The package manager Nix, just in case. - nix - - # Used to generate a NixOS image for testing. - nixos-generators - - # Used to extract things. - gnutar - - # Used to decompress things. - lzma - ]; - - # Customize this to include your GitHub username so we can track - # who is running which node. - name = "YOUR-GITHUB-USERNAME-tstest-integration-vms"; - - # Replace this with the path to the GitHub Actions runner token on - # your disk. - tokenFile = "/run/decrypted/ts-oss-ghaction-token"; - }; - - # A user account so there is a home directory and so they have kvm - # access. Please don't change this account name. - users.users.ghrunner = { - createHome = true; - isSystemUser = true; - extraGroups = [ "kvm" ]; - }; - - # The default github-runner service sets a lot of isolation features - # that attempt to limit the damage that malicious code can use. - # Unfortunately we rely on some "dangerous" features to do these tests, - # so this shim will peel some of them away. - systemd.services.github-runner = { - serviceConfig = { - # We need access to /dev to poke /dev/kvm. - PrivateDevices = lib.mkForce false; - - # /dev/kvm is how qemu creates a virtual machine with KVM. - DeviceAllow = lib.mkForce [ "/dev/kvm" ]; - - # Ensure the service has KVM permissions with the `kvm` group. - ExtraGroups = [ "kvm" ]; - - # The service runs as a dynamic user by default. This makes it hard - # to persistently store things in /var/lib/ghrunner. This line - # disables the dynamic user feature. - DynamicUser = lib.mkForce false; - - # Run this service as our ghrunner user. - User = "ghrunner"; - - # We need access to /var/lib/ghrunner to store VM images. - ProtectSystem = lib.mkForce null; - }; - }; -} +# This is a NixOS module to allow a machine to act as an integration test +# runner. This is used for the end-to-end VM test suite. + +{ lib, config, pkgs, ... }: + +{ + # The GitHub Actions self-hosted runner service. + services.github-runner = { + enable = true; + url = "https://github.com/tailscale/tailscale"; + replace = true; + extraLabels = [ "vm_integration_test" ]; + + # Justifications for the packages: + extraPackages = with pkgs; [ + # The test suite is written in Go. + go + + # This contains genisoimage, which is needed to create cloud-init + # seeds. + cdrkit + + # This package is the virtual machine hypervisor we use in tests. + qemu + + # This package contains tools like `ssh-keygen`. + openssh + + # The C compiler so cgo builds work. + gcc + + # The package manager Nix, just in case. + nix + + # Used to generate a NixOS image for testing. + nixos-generators + + # Used to extract things. + gnutar + + # Used to decompress things. + lzma + ]; + + # Customize this to include your GitHub username so we can track + # who is running which node. + name = "YOUR-GITHUB-USERNAME-tstest-integration-vms"; + + # Replace this with the path to the GitHub Actions runner token on + # your disk. + tokenFile = "/run/decrypted/ts-oss-ghaction-token"; + }; + + # A user account so there is a home directory and so they have kvm + # access. Please don't change this account name. + users.users.ghrunner = { + createHome = true; + isSystemUser = true; + extraGroups = [ "kvm" ]; + }; + + # The default github-runner service sets a lot of isolation features + # that attempt to limit the damage that malicious code can use. + # Unfortunately we rely on some "dangerous" features to do these tests, + # so this shim will peel some of them away. + systemd.services.github-runner = { + serviceConfig = { + # We need access to /dev to poke /dev/kvm. + PrivateDevices = lib.mkForce false; + + # /dev/kvm is how qemu creates a virtual machine with KVM. + DeviceAllow = lib.mkForce [ "/dev/kvm" ]; + + # Ensure the service has KVM permissions with the `kvm` group. + ExtraGroups = [ "kvm" ]; + + # The service runs as a dynamic user by default. This makes it hard + # to persistently store things in /var/lib/ghrunner. This line + # disables the dynamic user feature. + DynamicUser = lib.mkForce false; + + # Run this service as our ghrunner user. + User = "ghrunner"; + + # We need access to /var/lib/ghrunner to store VM images. + ProtectSystem = lib.mkForce null; + }; + }; +} diff --git a/tstest/integration/vms/squid.conf b/tstest/integration/vms/squid.conf index e43c5cd1f..29d32bd6d 100644 --- a/tstest/integration/vms/squid.conf +++ b/tstest/integration/vms/squid.conf @@ -1,39 +1,39 @@ -pid_filename /run/squid.pid -cache_dir ufs /tmp/squid/cache 500 16 256 -maximum_object_size 4096 KB -coredump_dir /tmp/squid/core -visible_hostname localhost -cache_access_log /tmp/squid/access.log -cache_log /tmp/squid/cache.log - -# Access Control lists -acl localhost src 127.0.0.1 ::1 -acl manager proto cache_object -acl SSL_ports port 443 -acl Safe_ports port 80 # http -acl Safe_ports port 21 # ftp -acl Safe_ports port 443 # https -acl Safe_ports port 70 # gopher -acl Safe_ports port 210 # wais -acl Safe_ports port 1025-65535 # unregistered ports -acl Safe_ports port 280 # http-mgmt -acl Safe_ports port 488 # gss-http -acl Safe_ports port 591 # filemaker -acl Safe_ports port 777 # multiling http -acl CONNECT method CONNECT - -http_access allow localhost -http_access deny all -forwarded_for on - -# sslcrtd_program /nix/store/nqlqk1f6qlxdirlrl1aijgb6vbzxs0gs-squid-4.17/libexec/security_file_certgen -s /tmp/squid/ssl_db -M 4MB -sslcrtd_children 5 - -http_port 127.0.0.1:3128 \ - ssl-bump \ - generate-host-certificates=on \ - dynamic_cert_mem_cache_size=4MB \ - cert=/tmp/squid/myca-mitm.pem - -ssl_bump stare all # mimic the Client Hello, drop unsupported extensions +pid_filename /run/squid.pid +cache_dir ufs /tmp/squid/cache 500 16 256 +maximum_object_size 4096 KB +coredump_dir /tmp/squid/core +visible_hostname localhost +cache_access_log /tmp/squid/access.log +cache_log /tmp/squid/cache.log + +# Access Control lists +acl localhost src 127.0.0.1 ::1 +acl manager proto cache_object +acl SSL_ports port 443 +acl Safe_ports port 80 # http +acl Safe_ports port 21 # ftp +acl Safe_ports port 443 # https +acl Safe_ports port 70 # gopher +acl Safe_ports port 210 # wais +acl Safe_ports port 1025-65535 # unregistered ports +acl Safe_ports port 280 # http-mgmt +acl Safe_ports port 488 # gss-http +acl Safe_ports port 591 # filemaker +acl Safe_ports port 777 # multiling http +acl CONNECT method CONNECT + +http_access allow localhost +http_access deny all +forwarded_for on + +# sslcrtd_program /nix/store/nqlqk1f6qlxdirlrl1aijgb6vbzxs0gs-squid-4.17/libexec/security_file_certgen -s /tmp/squid/ssl_db -M 4MB +sslcrtd_children 5 + +http_port 127.0.0.1:3128 \ + ssl-bump \ + generate-host-certificates=on \ + dynamic_cert_mem_cache_size=4MB \ + cert=/tmp/squid/myca-mitm.pem + +ssl_bump stare all # mimic the Client Hello, drop unsupported extensions ssl_bump bump all # terminate and establish new TLS connection \ No newline at end of file diff --git a/tstest/integration/vms/top_level_test.go b/tstest/integration/vms/top_level_test.go index 1b9c10e29..c107fd89c 100644 --- a/tstest/integration/vms/top_level_test.go +++ b/tstest/integration/vms/top_level_test.go @@ -1,124 +1,124 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows && !plan9 - -package vms - -import ( - "context" - "testing" - "time" - - "github.com/pkg/sftp" - expect "github.com/tailscale/goexpect" -) - -func TestRunUbuntu1804(t *testing.T) { - testOneDistribution(t, 0, Distros[0]) -} - -func TestRunUbuntu2004(t *testing.T) { - testOneDistribution(t, 1, Distros[1]) -} - -func TestRunNixos2111(t *testing.T) { - t.Parallel() - testOneDistribution(t, 2, Distros[2]) -} - -// TestMITMProxy is a smoke test for derphttp through a MITM proxy. -// Encountering such proxies is unfortunately commonplace in more -// traditional enterprise networks. -// -// We invoke tailscale netcheck because the networking check is done -// by tailscale rather than tailscaled, making it easier to configure -// the proxy. -// -// To provide the actual MITM server, we use squid. -func TestMITMProxy(t *testing.T) { - t.Parallel() - setupTests(t) - distro := Distros[2] // nixos-21.11 - - if distroRex.Unwrap().MatchString(distro.Name) { - t.Logf("%s matches %s", distro.Name, distroRex.Unwrap()) - } else { - t.Skip("regex not matched") - } - - ctx, done := context.WithCancel(context.Background()) - t.Cleanup(done) - - h := newHarness(t) - - err := ramsem.sem.Acquire(ctx, int64(distro.MemoryMegs)) - if err != nil { - t.Fatalf("can't acquire ram semaphore: %v", err) - } - t.Cleanup(func() { ramsem.sem.Release(int64(distro.MemoryMegs)) }) - - vm := h.mkVM(t, 2, distro, h.pubKey, h.loginServerURL, t.TempDir()) - vm.waitStartup(t) - - ipm := h.waitForIPMap(t, vm, distro) - _, cli := h.setupSSHShell(t, distro, ipm) - - sftpCli, err := sftp.NewClient(cli) - if err != nil { - t.Fatalf("can't connect over sftp to copy binaries: %v", err) - } - defer sftpCli.Close() - - // Initialize a squid installation. - // - // A few things of note here: - // - The first thing we do is append the nsslcrtd_program stanza to the config. - // This must be an absolute path and is based on the nix path of the squid derivation, - // so we compute and write it out here. - // - Squid expects a pre-initialized directory layout, so we create that in /tmp/squid then - // invoke squid with -z to have it fill in the rest. - // - Doing a meddler-in-the-middle attack requires using some fake keys, so we create - // them using openssl and then use the security_file_certgen tool to setup squids' ssl_db. - // - There were some perms issues, so i yeeted 0777. Its only a test anyway - copyFile(t, sftpCli, "squid.conf", "/tmp/squid.conf") - runTestCommands(t, 30*time.Second, cli, []expect.Batcher{ - &expect.BSnd{S: "echo -e \"\\nsslcrtd_program $(nix eval --raw nixpkgs.squid)/libexec/security_file_certgen -s /tmp/squid/ssl_db -M 4MB\\n\" >> /tmp/squid.conf\n"}, - &expect.BSnd{S: "mkdir -p /tmp/squid/{cache,core}\n"}, - &expect.BSnd{S: "openssl req -batch -new -newkey rsa:4096 -sha256 -days 3650 -nodes -x509 -keyout /tmp/squid/myca-mitm.pem -out /tmp/squid/myca-mitm.pem\n"}, - &expect.BExp{R: `writing new private key to '/tmp/squid/myca-mitm.pem'`}, - &expect.BSnd{S: "$(nix eval --raw nixpkgs.squid)/libexec/security_file_certgen -c -s /tmp/squid/ssl_db -M 4MB\n"}, - &expect.BExp{R: `Done`}, - &expect.BSnd{S: "sudo chmod -R 0777 /tmp/squid\n"}, - &expect.BSnd{S: "squid --foreground -YCs -z -f /tmp/squid.conf\n"}, - &expect.BSnd{S: "echo Success.\n"}, - &expect.BExp{R: `Success.`}, - }) - - // Start the squid server. - runTestCommands(t, 10*time.Second, cli, []expect.Batcher{ - &expect.BSnd{S: "daemonize -v -c /tmp/squid $(nix eval --raw nixpkgs.squid)/bin/squid --foreground -YCs -f /tmp/squid.conf\n"}, // start daemon - // NOTE(tom): Writing to /dev/tcp/* is bash magic, not a file. This - // eldritchian incantation lets us wait till squid is up. - &expect.BSnd{S: "while ! timeout 5 bash -c 'echo > /dev/tcp/localhost/3128'; do sleep 1; done\n"}, - &expect.BSnd{S: "echo Success.\n"}, - &expect.BExp{R: `Success.`}, - }) - - // Uncomment to help debugging this test if it fails. - // - // runTestCommands(t, 30 * time.Second, cli, []expect.Batcher{ - // &expect.BSnd{S: "sudo ifconfig\n"}, - // &expect.BSnd{S: "sudo ip link\n"}, - // &expect.BSnd{S: "sudo ip route\n"}, - // &expect.BSnd{S: "ps -aux\n"}, - // &expect.BSnd{S: "netstat -a\n"}, - // &expect.BSnd{S: "cat /tmp/squid/access.log && cat /tmp/squid/cache.log && cat /tmp/squid.conf && echo Success.\n"}, - // &expect.BExp{R: `Success.`}, - // }) - - runTestCommands(t, 30*time.Second, cli, []expect.Batcher{ - &expect.BSnd{S: "SSL_CERT_FILE=/tmp/squid/myca-mitm.pem HTTPS_PROXY=http://127.0.0.1:3128 tailscale netcheck\n"}, - &expect.BExp{R: `IPv4: yes`}, - }) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows && !plan9 + +package vms + +import ( + "context" + "testing" + "time" + + "github.com/pkg/sftp" + expect "github.com/tailscale/goexpect" +) + +func TestRunUbuntu1804(t *testing.T) { + testOneDistribution(t, 0, Distros[0]) +} + +func TestRunUbuntu2004(t *testing.T) { + testOneDistribution(t, 1, Distros[1]) +} + +func TestRunNixos2111(t *testing.T) { + t.Parallel() + testOneDistribution(t, 2, Distros[2]) +} + +// TestMITMProxy is a smoke test for derphttp through a MITM proxy. +// Encountering such proxies is unfortunately commonplace in more +// traditional enterprise networks. +// +// We invoke tailscale netcheck because the networking check is done +// by tailscale rather than tailscaled, making it easier to configure +// the proxy. +// +// To provide the actual MITM server, we use squid. +func TestMITMProxy(t *testing.T) { + t.Parallel() + setupTests(t) + distro := Distros[2] // nixos-21.11 + + if distroRex.Unwrap().MatchString(distro.Name) { + t.Logf("%s matches %s", distro.Name, distroRex.Unwrap()) + } else { + t.Skip("regex not matched") + } + + ctx, done := context.WithCancel(context.Background()) + t.Cleanup(done) + + h := newHarness(t) + + err := ramsem.sem.Acquire(ctx, int64(distro.MemoryMegs)) + if err != nil { + t.Fatalf("can't acquire ram semaphore: %v", err) + } + t.Cleanup(func() { ramsem.sem.Release(int64(distro.MemoryMegs)) }) + + vm := h.mkVM(t, 2, distro, h.pubKey, h.loginServerURL, t.TempDir()) + vm.waitStartup(t) + + ipm := h.waitForIPMap(t, vm, distro) + _, cli := h.setupSSHShell(t, distro, ipm) + + sftpCli, err := sftp.NewClient(cli) + if err != nil { + t.Fatalf("can't connect over sftp to copy binaries: %v", err) + } + defer sftpCli.Close() + + // Initialize a squid installation. + // + // A few things of note here: + // - The first thing we do is append the nsslcrtd_program stanza to the config. + // This must be an absolute path and is based on the nix path of the squid derivation, + // so we compute and write it out here. + // - Squid expects a pre-initialized directory layout, so we create that in /tmp/squid then + // invoke squid with -z to have it fill in the rest. + // - Doing a meddler-in-the-middle attack requires using some fake keys, so we create + // them using openssl and then use the security_file_certgen tool to setup squids' ssl_db. + // - There were some perms issues, so i yeeted 0777. Its only a test anyway + copyFile(t, sftpCli, "squid.conf", "/tmp/squid.conf") + runTestCommands(t, 30*time.Second, cli, []expect.Batcher{ + &expect.BSnd{S: "echo -e \"\\nsslcrtd_program $(nix eval --raw nixpkgs.squid)/libexec/security_file_certgen -s /tmp/squid/ssl_db -M 4MB\\n\" >> /tmp/squid.conf\n"}, + &expect.BSnd{S: "mkdir -p /tmp/squid/{cache,core}\n"}, + &expect.BSnd{S: "openssl req -batch -new -newkey rsa:4096 -sha256 -days 3650 -nodes -x509 -keyout /tmp/squid/myca-mitm.pem -out /tmp/squid/myca-mitm.pem\n"}, + &expect.BExp{R: `writing new private key to '/tmp/squid/myca-mitm.pem'`}, + &expect.BSnd{S: "$(nix eval --raw nixpkgs.squid)/libexec/security_file_certgen -c -s /tmp/squid/ssl_db -M 4MB\n"}, + &expect.BExp{R: `Done`}, + &expect.BSnd{S: "sudo chmod -R 0777 /tmp/squid\n"}, + &expect.BSnd{S: "squid --foreground -YCs -z -f /tmp/squid.conf\n"}, + &expect.BSnd{S: "echo Success.\n"}, + &expect.BExp{R: `Success.`}, + }) + + // Start the squid server. + runTestCommands(t, 10*time.Second, cli, []expect.Batcher{ + &expect.BSnd{S: "daemonize -v -c /tmp/squid $(nix eval --raw nixpkgs.squid)/bin/squid --foreground -YCs -f /tmp/squid.conf\n"}, // start daemon + // NOTE(tom): Writing to /dev/tcp/* is bash magic, not a file. This + // eldritchian incantation lets us wait till squid is up. + &expect.BSnd{S: "while ! timeout 5 bash -c 'echo > /dev/tcp/localhost/3128'; do sleep 1; done\n"}, + &expect.BSnd{S: "echo Success.\n"}, + &expect.BExp{R: `Success.`}, + }) + + // Uncomment to help debugging this test if it fails. + // + // runTestCommands(t, 30 * time.Second, cli, []expect.Batcher{ + // &expect.BSnd{S: "sudo ifconfig\n"}, + // &expect.BSnd{S: "sudo ip link\n"}, + // &expect.BSnd{S: "sudo ip route\n"}, + // &expect.BSnd{S: "ps -aux\n"}, + // &expect.BSnd{S: "netstat -a\n"}, + // &expect.BSnd{S: "cat /tmp/squid/access.log && cat /tmp/squid/cache.log && cat /tmp/squid.conf && echo Success.\n"}, + // &expect.BExp{R: `Success.`}, + // }) + + runTestCommands(t, 30*time.Second, cli, []expect.Batcher{ + &expect.BSnd{S: "SSL_CERT_FILE=/tmp/squid/myca-mitm.pem HTTPS_PROXY=http://127.0.0.1:3128 tailscale netcheck\n"}, + &expect.BExp{R: `IPv4: yes`}, + }) +} diff --git a/tstest/integration/vms/udp_tester.go b/tstest/integration/vms/udp_tester.go index 14c8c6ed0..be44aa963 100644 --- a/tstest/integration/vms/udp_tester.go +++ b/tstest/integration/vms/udp_tester.go @@ -1,77 +1,77 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ignore - -// Command udp_tester exists because all of these distros being tested don't -// have a consistent tool for doing UDP traffic. This is a very hacked up tool -// that does that UDP traffic so these tests can be done. -package main - -import ( - "flag" - "io" - "log" - "net" - "os" -) - -var ( - client = flag.String("client", "", "host:port to connect to for sending UDP") - server = flag.String("server", "", "host:port to bind to for receiving UDP") -) - -func main() { - flag.Parse() - - if *client == "" && *server == "" { - log.Fatal("specify -client or -server") - } - - if *client != "" { - conn, err := net.Dial("udp", *client) - if err != nil { - log.Fatalf("can't dial %s: %v", *client, err) - } - log.Printf("dialed to %s", conn.RemoteAddr()) - defer conn.Close() - - buf := make([]byte, 2048) - n, err := os.Stdin.Read(buf) - if err != nil && err != io.EOF { - log.Fatalf("can't read from stdin: %v", err) - } - - nn, err := conn.Write(buf[:n]) - if err != nil { - log.Fatalf("can't write to %s: %v", conn.RemoteAddr(), err) - } - - if n == nn { - return - } - - log.Fatalf("wanted to write %d bytes, wrote %d bytes", n, nn) - } - - if *server != "" { - addr, err := net.ResolveUDPAddr("udp", *server) - if err != nil { - log.Fatalf("can't resolve %s: %v", *server, err) - } - ln, err := net.ListenUDP("udp", addr) - if err != nil { - log.Fatalf("can't listen %s: %v", *server, err) - } - defer ln.Close() - - buf := make([]byte, 2048) - - n, _, err := ln.ReadFromUDP(buf) - if err != nil { - log.Fatal(err) - } - - os.Stdout.Write(buf[:n]) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ignore + +// Command udp_tester exists because all of these distros being tested don't +// have a consistent tool for doing UDP traffic. This is a very hacked up tool +// that does that UDP traffic so these tests can be done. +package main + +import ( + "flag" + "io" + "log" + "net" + "os" +) + +var ( + client = flag.String("client", "", "host:port to connect to for sending UDP") + server = flag.String("server", "", "host:port to bind to for receiving UDP") +) + +func main() { + flag.Parse() + + if *client == "" && *server == "" { + log.Fatal("specify -client or -server") + } + + if *client != "" { + conn, err := net.Dial("udp", *client) + if err != nil { + log.Fatalf("can't dial %s: %v", *client, err) + } + log.Printf("dialed to %s", conn.RemoteAddr()) + defer conn.Close() + + buf := make([]byte, 2048) + n, err := os.Stdin.Read(buf) + if err != nil && err != io.EOF { + log.Fatalf("can't read from stdin: %v", err) + } + + nn, err := conn.Write(buf[:n]) + if err != nil { + log.Fatalf("can't write to %s: %v", conn.RemoteAddr(), err) + } + + if n == nn { + return + } + + log.Fatalf("wanted to write %d bytes, wrote %d bytes", n, nn) + } + + if *server != "" { + addr, err := net.ResolveUDPAddr("udp", *server) + if err != nil { + log.Fatalf("can't resolve %s: %v", *server, err) + } + ln, err := net.ListenUDP("udp", addr) + if err != nil { + log.Fatalf("can't listen %s: %v", *server, err) + } + defer ln.Close() + + buf := make([]byte, 2048) + + n, _, err := ln.ReadFromUDP(buf) + if err != nil { + log.Fatal(err) + } + + os.Stdout.Write(buf[:n]) + } +} diff --git a/tstest/log_test.go b/tstest/log_test.go index a8cb62cf5..51a5743c2 100644 --- a/tstest/log_test.go +++ b/tstest/log_test.go @@ -1,47 +1,47 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstest - -import ( - "reflect" - "testing" -) - -func TestLogLineTracker(t *testing.T) { - const ( - l1 = "line 1: %s" - l2 = "line 2: %s" - l3 = "line 3: %s" - ) - - lt := NewLogLineTracker(t.Logf, []string{l1, l2}) - - if got, want := lt.Check(), []string{l1, l2}; !reflect.DeepEqual(got, want) { - t.Errorf("Check = %q; want %q", got, want) - } - - lt.Logf(l3, "hi") - - if got, want := lt.Check(), []string{l1, l2}; !reflect.DeepEqual(got, want) { - t.Errorf("Check = %q; want %q", got, want) - } - - lt.Logf(l1, "hi") - - if got, want := lt.Check(), []string{l2}; !reflect.DeepEqual(got, want) { - t.Errorf("Check = %q; want %q", got, want) - } - - lt.Logf(l1, "bye") - - if got, want := lt.Check(), []string{l2}; !reflect.DeepEqual(got, want) { - t.Errorf("Check = %q; want %q", got, want) - } - - lt.Logf(l2, "hi") - - if got, want := lt.Check(), []string(nil); !reflect.DeepEqual(got, want) { - t.Errorf("Check = %q; want %q", got, want) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstest + +import ( + "reflect" + "testing" +) + +func TestLogLineTracker(t *testing.T) { + const ( + l1 = "line 1: %s" + l2 = "line 2: %s" + l3 = "line 3: %s" + ) + + lt := NewLogLineTracker(t.Logf, []string{l1, l2}) + + if got, want := lt.Check(), []string{l1, l2}; !reflect.DeepEqual(got, want) { + t.Errorf("Check = %q; want %q", got, want) + } + + lt.Logf(l3, "hi") + + if got, want := lt.Check(), []string{l1, l2}; !reflect.DeepEqual(got, want) { + t.Errorf("Check = %q; want %q", got, want) + } + + lt.Logf(l1, "hi") + + if got, want := lt.Check(), []string{l2}; !reflect.DeepEqual(got, want) { + t.Errorf("Check = %q; want %q", got, want) + } + + lt.Logf(l1, "bye") + + if got, want := lt.Check(), []string{l2}; !reflect.DeepEqual(got, want) { + t.Errorf("Check = %q; want %q", got, want) + } + + lt.Logf(l2, "hi") + + if got, want := lt.Check(), []string(nil); !reflect.DeepEqual(got, want) { + t.Errorf("Check = %q; want %q", got, want) + } +} diff --git a/tstest/natlab/firewall.go b/tstest/natlab/firewall.go index 851f1c56d..c427d6692 100644 --- a/tstest/natlab/firewall.go +++ b/tstest/natlab/firewall.go @@ -1,156 +1,156 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package natlab - -import ( - "fmt" - "net/netip" - "sync" - "time" - - "tailscale.com/util/mak" -) - -// FirewallType is the type of filtering a stateful firewall -// does. Values express different modes defined by RFC 4787. -type FirewallType int - -const ( - // AddressAndPortDependentFirewall specifies a destination - // address-and-port dependent firewall. Outbound traffic to an - // ip:port authorizes traffic from that ip:port exactly, and - // nothing else. - AddressAndPortDependentFirewall FirewallType = iota - // AddressDependentFirewall specifies a destination address - // dependent firewall. Once outbound traffic has been seen to an - // IP address, that IP address can talk back from any port. - AddressDependentFirewall - // EndpointIndependentFirewall specifies a destination endpoint - // independent firewall. Once outbound traffic has been seen from - // a source, anyone can talk back to that source. - EndpointIndependentFirewall -) - -// fwKey is the lookup key for a firewall session. While it contains a -// 4-tuple ({src,dst} {ip,port}), some FirewallTypes will zero out -// some fields, so in practice the key is either a 2-tuple (src only), -// 3-tuple (src ip+port and dst ip) or 4-tuple (src+dst ip+port). -type fwKey struct { - src netip.AddrPort - dst netip.AddrPort -} - -// key returns an fwKey for the given src and dst, trimmed according -// to the FirewallType. fwKeys are always constructed from the -// "outbound" point of view (i.e. src is the "trusted" side of the -// world), it's the caller's responsibility to swap src and dst in the -// call to key when processing packets inbound from the "untrusted" -// world. -func (s FirewallType) key(src, dst netip.AddrPort) fwKey { - k := fwKey{src: src} - switch s { - case EndpointIndependentFirewall: - case AddressDependentFirewall: - k.dst = netip.AddrPortFrom(dst.Addr(), k.dst.Port()) - case AddressAndPortDependentFirewall: - k.dst = dst - default: - panic(fmt.Sprintf("unknown firewall selectivity %v", s)) - } - return k -} - -// DefaultSessionTimeout is the default timeout for a firewall -// session. -const DefaultSessionTimeout = 30 * time.Second - -// Firewall is a simple stateful firewall that allows all outbound -// traffic and filters inbound traffic based on recently seen outbound -// traffic. Its HandlePacket method should be attached to a Machine to -// give it a stateful firewall. -type Firewall struct { - // SessionTimeout is the lifetime of idle sessions in the firewall - // state. Packets transiting from the TrustedInterface reset the - // session lifetime to SessionTimeout. If zero, - // DefaultSessionTimeout is used. - SessionTimeout time.Duration - // Type specifies how precisely return traffic must match - // previously seen outbound traffic to be allowed. Defaults to - // AddressAndPortDependentFirewall. - Type FirewallType - // TrustedInterface is an optional interface that is considered - // trusted in addition to PacketConns local to the Machine. All - // other interfaces can only respond to traffic from - // TrustedInterface or the local host. - TrustedInterface *Interface - // TimeNow is a function returning the current time. If nil, - // time.Now is used. - TimeNow func() time.Time - - // TODO: refresh directionality: outbound-only, both - - mu sync.Mutex - seen map[fwKey]time.Time // session -> deadline -} - -func (f *Firewall) timeNow() time.Time { - if f.TimeNow != nil { - return f.TimeNow() - } - return time.Now() -} - -// Reset drops all firewall state, forgetting all flows. -func (f *Firewall) Reset() { - f.mu.Lock() - defer f.mu.Unlock() - f.seen = nil -} - -func (f *Firewall) HandleOut(p *Packet, oif *Interface) *Packet { - f.mu.Lock() - defer f.mu.Unlock() - - k := f.Type.key(p.Src, p.Dst) - mak.Set(&f.seen, k, f.timeNow().Add(f.sessionTimeoutLocked())) - p.Trace("firewall out ok") - return p -} - -func (f *Firewall) HandleIn(p *Packet, iif *Interface) *Packet { - f.mu.Lock() - defer f.mu.Unlock() - - // reverse src and dst because the session table is from the POV - // of outbound packets. - k := f.Type.key(p.Dst, p.Src) - now := f.timeNow() - if now.After(f.seen[k]) { - p.Trace("firewall drop") - return nil - } - p.Trace("firewall in ok") - return p -} - -func (f *Firewall) HandleForward(p *Packet, iif *Interface, oif *Interface) *Packet { - if iif == f.TrustedInterface { - // Treat just like a locally originated packet - return f.HandleOut(p, oif) - } - if oif != f.TrustedInterface { - // Not a possible return packet from our trusted interface, drop. - p.Trace("firewall drop, unexpected oif") - return nil - } - // Otherwise, a session must exist, same as HandleIn. - return f.HandleIn(p, iif) -} - -func (f *Firewall) sessionTimeoutLocked() time.Duration { - if f.SessionTimeout == 0 { - return DefaultSessionTimeout - } - return f.SessionTimeout -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package natlab + +import ( + "fmt" + "net/netip" + "sync" + "time" + + "tailscale.com/util/mak" +) + +// FirewallType is the type of filtering a stateful firewall +// does. Values express different modes defined by RFC 4787. +type FirewallType int + +const ( + // AddressAndPortDependentFirewall specifies a destination + // address-and-port dependent firewall. Outbound traffic to an + // ip:port authorizes traffic from that ip:port exactly, and + // nothing else. + AddressAndPortDependentFirewall FirewallType = iota + // AddressDependentFirewall specifies a destination address + // dependent firewall. Once outbound traffic has been seen to an + // IP address, that IP address can talk back from any port. + AddressDependentFirewall + // EndpointIndependentFirewall specifies a destination endpoint + // independent firewall. Once outbound traffic has been seen from + // a source, anyone can talk back to that source. + EndpointIndependentFirewall +) + +// fwKey is the lookup key for a firewall session. While it contains a +// 4-tuple ({src,dst} {ip,port}), some FirewallTypes will zero out +// some fields, so in practice the key is either a 2-tuple (src only), +// 3-tuple (src ip+port and dst ip) or 4-tuple (src+dst ip+port). +type fwKey struct { + src netip.AddrPort + dst netip.AddrPort +} + +// key returns an fwKey for the given src and dst, trimmed according +// to the FirewallType. fwKeys are always constructed from the +// "outbound" point of view (i.e. src is the "trusted" side of the +// world), it's the caller's responsibility to swap src and dst in the +// call to key when processing packets inbound from the "untrusted" +// world. +func (s FirewallType) key(src, dst netip.AddrPort) fwKey { + k := fwKey{src: src} + switch s { + case EndpointIndependentFirewall: + case AddressDependentFirewall: + k.dst = netip.AddrPortFrom(dst.Addr(), k.dst.Port()) + case AddressAndPortDependentFirewall: + k.dst = dst + default: + panic(fmt.Sprintf("unknown firewall selectivity %v", s)) + } + return k +} + +// DefaultSessionTimeout is the default timeout for a firewall +// session. +const DefaultSessionTimeout = 30 * time.Second + +// Firewall is a simple stateful firewall that allows all outbound +// traffic and filters inbound traffic based on recently seen outbound +// traffic. Its HandlePacket method should be attached to a Machine to +// give it a stateful firewall. +type Firewall struct { + // SessionTimeout is the lifetime of idle sessions in the firewall + // state. Packets transiting from the TrustedInterface reset the + // session lifetime to SessionTimeout. If zero, + // DefaultSessionTimeout is used. + SessionTimeout time.Duration + // Type specifies how precisely return traffic must match + // previously seen outbound traffic to be allowed. Defaults to + // AddressAndPortDependentFirewall. + Type FirewallType + // TrustedInterface is an optional interface that is considered + // trusted in addition to PacketConns local to the Machine. All + // other interfaces can only respond to traffic from + // TrustedInterface or the local host. + TrustedInterface *Interface + // TimeNow is a function returning the current time. If nil, + // time.Now is used. + TimeNow func() time.Time + + // TODO: refresh directionality: outbound-only, both + + mu sync.Mutex + seen map[fwKey]time.Time // session -> deadline +} + +func (f *Firewall) timeNow() time.Time { + if f.TimeNow != nil { + return f.TimeNow() + } + return time.Now() +} + +// Reset drops all firewall state, forgetting all flows. +func (f *Firewall) Reset() { + f.mu.Lock() + defer f.mu.Unlock() + f.seen = nil +} + +func (f *Firewall) HandleOut(p *Packet, oif *Interface) *Packet { + f.mu.Lock() + defer f.mu.Unlock() + + k := f.Type.key(p.Src, p.Dst) + mak.Set(&f.seen, k, f.timeNow().Add(f.sessionTimeoutLocked())) + p.Trace("firewall out ok") + return p +} + +func (f *Firewall) HandleIn(p *Packet, iif *Interface) *Packet { + f.mu.Lock() + defer f.mu.Unlock() + + // reverse src and dst because the session table is from the POV + // of outbound packets. + k := f.Type.key(p.Dst, p.Src) + now := f.timeNow() + if now.After(f.seen[k]) { + p.Trace("firewall drop") + return nil + } + p.Trace("firewall in ok") + return p +} + +func (f *Firewall) HandleForward(p *Packet, iif *Interface, oif *Interface) *Packet { + if iif == f.TrustedInterface { + // Treat just like a locally originated packet + return f.HandleOut(p, oif) + } + if oif != f.TrustedInterface { + // Not a possible return packet from our trusted interface, drop. + p.Trace("firewall drop, unexpected oif") + return nil + } + // Otherwise, a session must exist, same as HandleIn. + return f.HandleIn(p, iif) +} + +func (f *Firewall) sessionTimeoutLocked() time.Duration { + if f.SessionTimeout == 0 { + return DefaultSessionTimeout + } + return f.SessionTimeout +} diff --git a/tstest/natlab/nat.go b/tstest/natlab/nat.go index 36b1322cd..d756c5bf1 100644 --- a/tstest/natlab/nat.go +++ b/tstest/natlab/nat.go @@ -1,252 +1,252 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package natlab - -import ( - "context" - "fmt" - "net" - "net/netip" - "sync" - "time" -) - -// mapping is the state of an allocated NAT session. -type mapping struct { - lanSrc netip.AddrPort - lanDst netip.AddrPort - wanSrc netip.AddrPort - deadline time.Time - - // pc is a PacketConn that reserves an outbound port on the NAT's - // WAN interface. We do this because ListenPacket already has - // random port selection logic built in. Additionally this means - // that concurrent use of ListenPacket for connections originating - // from the NAT box won't conflict with NAT mappings, since both - // use PacketConn to reserve ports on the machine. - pc net.PacketConn -} - -// NATType is the mapping behavior of a NAT device. Values express -// different modes defined by RFC 4787. -type NATType int - -const ( - // EndpointIndependentNAT specifies a destination endpoint - // independent NAT. All traffic from a source ip:port gets mapped - // to a single WAN ip:port. - EndpointIndependentNAT NATType = iota - // AddressDependentNAT specifies a destination address dependent - // NAT. Every distinct destination IP gets its own WAN ip:port - // allocation. - AddressDependentNAT - // AddressAndPortDependentNAT specifies a destination - // address-and-port dependent NAT. Every distinct destination - // ip:port gets its own WAN ip:port allocation. - AddressAndPortDependentNAT -) - -// natKey is the lookup key for a NAT session. While it contains a -// 4-tuple ({src,dst} {ip,port}), some NATTypes will zero out some -// fields, so in practice the key is either a 2-tuple (src only), -// 3-tuple (src ip+port and dst ip) or 4-tuple (src+dst ip+port). -type natKey struct { - src, dst netip.AddrPort -} - -func (t NATType) key(src, dst netip.AddrPort) natKey { - k := natKey{src: src} - switch t { - case EndpointIndependentNAT: - case AddressDependentNAT: - k.dst = netip.AddrPortFrom(dst.Addr(), k.dst.Port()) - case AddressAndPortDependentNAT: - k.dst = dst - default: - panic(fmt.Sprintf("unknown NAT type %v", t)) - } - return k -} - -// DefaultMappingTimeout is the default timeout for a NAT mapping. -const DefaultMappingTimeout = 30 * time.Second - -// SNAT44 implements an IPv4-to-IPv4 source NAT (SNAT) translator, with -// optional builtin firewall. -type SNAT44 struct { - // Machine is the machine to which this NAT is attached. Altered - // packets are injected back into this Machine for processing. - Machine *Machine - // ExternalInterface is the "WAN" interface of Machine. Packets - // from other sources get NATed onto this interface. - ExternalInterface *Interface - // Type specifies the mapping allocation behavior for this NAT. - Type NATType - // MappingTimeout is the lifetime of individual NAT sessions. Once - // a session expires, the mapped port effectively "closes" to new - // traffic. If MappingTimeout is 0, DefaultMappingTimeout is used. - MappingTimeout time.Duration - // Firewall is an optional packet handler that will be invoked as - // a firewall during NAT translation. The firewall always sees - // packets in their "LAN form", i.e. before translation in the - // outbound direction and after translation in the inbound - // direction. - Firewall PacketHandler - // TimeNow is a function that returns the current time. If - // nil, time.Now is used. - TimeNow func() time.Time - - mu sync.Mutex - byLAN map[natKey]*mapping // lookup by outbound packet tuple - byWAN map[netip.AddrPort]*mapping // lookup by wan ip:port only -} - -func (n *SNAT44) timeNow() time.Time { - if n.TimeNow != nil { - return n.TimeNow() - } - return time.Now() -} - -func (n *SNAT44) mappingTimeout() time.Duration { - if n.MappingTimeout == 0 { - return DefaultMappingTimeout - } - return n.MappingTimeout -} - -func (n *SNAT44) initLocked() { - if n.byLAN == nil { - n.byLAN = map[natKey]*mapping{} - n.byWAN = map[netip.AddrPort]*mapping{} - } - if n.ExternalInterface.Machine() != n.Machine { - panic(fmt.Sprintf("NAT given interface %s that is not part of given machine %s", n.ExternalInterface, n.Machine.Name)) - } -} - -func (n *SNAT44) HandleOut(p *Packet, oif *Interface) *Packet { - // NATs don't affect locally originated packets. - if n.Firewall != nil { - return n.Firewall.HandleOut(p, oif) - } - return p -} - -func (n *SNAT44) HandleIn(p *Packet, iif *Interface) *Packet { - if iif != n.ExternalInterface { - // NAT can't apply, defer to firewall. - if n.Firewall != nil { - return n.Firewall.HandleIn(p, iif) - } - return p - } - - n.mu.Lock() - defer n.mu.Unlock() - n.initLocked() - - now := n.timeNow() - mapping := n.byWAN[p.Dst] - if mapping == nil || now.After(mapping.deadline) { - // NAT didn't hit, defer to firewall or allow in for local - // socket handling. - if n.Firewall != nil { - return n.Firewall.HandleIn(p, iif) - } - return p - } - - p.Dst = mapping.lanSrc - p.Trace("dnat to %v", p.Dst) - // Don't process firewall here. We mutated the packet such that - // it's no longer destined locally, so we'll get reinvoked as - // HandleForward and need to process the altered packet there. - return p -} - -func (n *SNAT44) HandleForward(p *Packet, iif, oif *Interface) *Packet { - switch { - case oif == n.ExternalInterface: - if p.Src.Addr() == oif.V4() { - // Packet already NATed and is just retraversing Forward, - // don't touch it again. - return p - } - - if n.Firewall != nil { - p2 := n.Firewall.HandleForward(p, iif, oif) - if p2 == nil { - // firewall dropped, done - return nil - } - if !p.Equivalent(p2) { - // firewall mutated packet? Weird, but okay. - return p2 - } - } - - n.mu.Lock() - defer n.mu.Unlock() - n.initLocked() - - k := n.Type.key(p.Src, p.Dst) - now := n.timeNow() - m := n.byLAN[k] - if m == nil || now.After(m.deadline) { - pc, wanAddr := n.allocateMappedPort() - m = &mapping{ - lanSrc: p.Src, - lanDst: p.Dst, - wanSrc: wanAddr, - pc: pc, - } - n.byLAN[k] = m - n.byWAN[wanAddr] = m - } - m.deadline = now.Add(n.mappingTimeout()) - p.Src = m.wanSrc - p.Trace("snat from %v", p.Src) - return p - case iif == n.ExternalInterface: - // Packet was already un-NAT-ed, we just need to either - // firewall it or let it through. - if n.Firewall != nil { - return n.Firewall.HandleForward(p, iif, oif) - } - return p - default: - // No NAT applies, invoke firewall or drop. - if n.Firewall != nil { - return n.Firewall.HandleForward(p, iif, oif) - } - return nil - } -} - -func (n *SNAT44) allocateMappedPort() (net.PacketConn, netip.AddrPort) { - // Clean up old entries before trying to allocate, to free up any - // expired ports. - n.gc() - - ip := n.ExternalInterface.V4() - pc, err := n.Machine.ListenPacket(context.Background(), "udp", net.JoinHostPort(ip.String(), "0")) - if err != nil { - panic(fmt.Sprintf("ran out of NAT ports: %v", err)) - } - addr := netip.AddrPortFrom(ip, uint16(pc.LocalAddr().(*net.UDPAddr).Port)) - return pc, addr -} - -func (n *SNAT44) gc() { - now := n.timeNow() - for _, m := range n.byLAN { - if !now.After(m.deadline) { - continue - } - m.pc.Close() - delete(n.byLAN, n.Type.key(m.lanSrc, m.lanDst)) - delete(n.byWAN, m.wanSrc) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package natlab + +import ( + "context" + "fmt" + "net" + "net/netip" + "sync" + "time" +) + +// mapping is the state of an allocated NAT session. +type mapping struct { + lanSrc netip.AddrPort + lanDst netip.AddrPort + wanSrc netip.AddrPort + deadline time.Time + + // pc is a PacketConn that reserves an outbound port on the NAT's + // WAN interface. We do this because ListenPacket already has + // random port selection logic built in. Additionally this means + // that concurrent use of ListenPacket for connections originating + // from the NAT box won't conflict with NAT mappings, since both + // use PacketConn to reserve ports on the machine. + pc net.PacketConn +} + +// NATType is the mapping behavior of a NAT device. Values express +// different modes defined by RFC 4787. +type NATType int + +const ( + // EndpointIndependentNAT specifies a destination endpoint + // independent NAT. All traffic from a source ip:port gets mapped + // to a single WAN ip:port. + EndpointIndependentNAT NATType = iota + // AddressDependentNAT specifies a destination address dependent + // NAT. Every distinct destination IP gets its own WAN ip:port + // allocation. + AddressDependentNAT + // AddressAndPortDependentNAT specifies a destination + // address-and-port dependent NAT. Every distinct destination + // ip:port gets its own WAN ip:port allocation. + AddressAndPortDependentNAT +) + +// natKey is the lookup key for a NAT session. While it contains a +// 4-tuple ({src,dst} {ip,port}), some NATTypes will zero out some +// fields, so in practice the key is either a 2-tuple (src only), +// 3-tuple (src ip+port and dst ip) or 4-tuple (src+dst ip+port). +type natKey struct { + src, dst netip.AddrPort +} + +func (t NATType) key(src, dst netip.AddrPort) natKey { + k := natKey{src: src} + switch t { + case EndpointIndependentNAT: + case AddressDependentNAT: + k.dst = netip.AddrPortFrom(dst.Addr(), k.dst.Port()) + case AddressAndPortDependentNAT: + k.dst = dst + default: + panic(fmt.Sprintf("unknown NAT type %v", t)) + } + return k +} + +// DefaultMappingTimeout is the default timeout for a NAT mapping. +const DefaultMappingTimeout = 30 * time.Second + +// SNAT44 implements an IPv4-to-IPv4 source NAT (SNAT) translator, with +// optional builtin firewall. +type SNAT44 struct { + // Machine is the machine to which this NAT is attached. Altered + // packets are injected back into this Machine for processing. + Machine *Machine + // ExternalInterface is the "WAN" interface of Machine. Packets + // from other sources get NATed onto this interface. + ExternalInterface *Interface + // Type specifies the mapping allocation behavior for this NAT. + Type NATType + // MappingTimeout is the lifetime of individual NAT sessions. Once + // a session expires, the mapped port effectively "closes" to new + // traffic. If MappingTimeout is 0, DefaultMappingTimeout is used. + MappingTimeout time.Duration + // Firewall is an optional packet handler that will be invoked as + // a firewall during NAT translation. The firewall always sees + // packets in their "LAN form", i.e. before translation in the + // outbound direction and after translation in the inbound + // direction. + Firewall PacketHandler + // TimeNow is a function that returns the current time. If + // nil, time.Now is used. + TimeNow func() time.Time + + mu sync.Mutex + byLAN map[natKey]*mapping // lookup by outbound packet tuple + byWAN map[netip.AddrPort]*mapping // lookup by wan ip:port only +} + +func (n *SNAT44) timeNow() time.Time { + if n.TimeNow != nil { + return n.TimeNow() + } + return time.Now() +} + +func (n *SNAT44) mappingTimeout() time.Duration { + if n.MappingTimeout == 0 { + return DefaultMappingTimeout + } + return n.MappingTimeout +} + +func (n *SNAT44) initLocked() { + if n.byLAN == nil { + n.byLAN = map[natKey]*mapping{} + n.byWAN = map[netip.AddrPort]*mapping{} + } + if n.ExternalInterface.Machine() != n.Machine { + panic(fmt.Sprintf("NAT given interface %s that is not part of given machine %s", n.ExternalInterface, n.Machine.Name)) + } +} + +func (n *SNAT44) HandleOut(p *Packet, oif *Interface) *Packet { + // NATs don't affect locally originated packets. + if n.Firewall != nil { + return n.Firewall.HandleOut(p, oif) + } + return p +} + +func (n *SNAT44) HandleIn(p *Packet, iif *Interface) *Packet { + if iif != n.ExternalInterface { + // NAT can't apply, defer to firewall. + if n.Firewall != nil { + return n.Firewall.HandleIn(p, iif) + } + return p + } + + n.mu.Lock() + defer n.mu.Unlock() + n.initLocked() + + now := n.timeNow() + mapping := n.byWAN[p.Dst] + if mapping == nil || now.After(mapping.deadline) { + // NAT didn't hit, defer to firewall or allow in for local + // socket handling. + if n.Firewall != nil { + return n.Firewall.HandleIn(p, iif) + } + return p + } + + p.Dst = mapping.lanSrc + p.Trace("dnat to %v", p.Dst) + // Don't process firewall here. We mutated the packet such that + // it's no longer destined locally, so we'll get reinvoked as + // HandleForward and need to process the altered packet there. + return p +} + +func (n *SNAT44) HandleForward(p *Packet, iif, oif *Interface) *Packet { + switch { + case oif == n.ExternalInterface: + if p.Src.Addr() == oif.V4() { + // Packet already NATed and is just retraversing Forward, + // don't touch it again. + return p + } + + if n.Firewall != nil { + p2 := n.Firewall.HandleForward(p, iif, oif) + if p2 == nil { + // firewall dropped, done + return nil + } + if !p.Equivalent(p2) { + // firewall mutated packet? Weird, but okay. + return p2 + } + } + + n.mu.Lock() + defer n.mu.Unlock() + n.initLocked() + + k := n.Type.key(p.Src, p.Dst) + now := n.timeNow() + m := n.byLAN[k] + if m == nil || now.After(m.deadline) { + pc, wanAddr := n.allocateMappedPort() + m = &mapping{ + lanSrc: p.Src, + lanDst: p.Dst, + wanSrc: wanAddr, + pc: pc, + } + n.byLAN[k] = m + n.byWAN[wanAddr] = m + } + m.deadline = now.Add(n.mappingTimeout()) + p.Src = m.wanSrc + p.Trace("snat from %v", p.Src) + return p + case iif == n.ExternalInterface: + // Packet was already un-NAT-ed, we just need to either + // firewall it or let it through. + if n.Firewall != nil { + return n.Firewall.HandleForward(p, iif, oif) + } + return p + default: + // No NAT applies, invoke firewall or drop. + if n.Firewall != nil { + return n.Firewall.HandleForward(p, iif, oif) + } + return nil + } +} + +func (n *SNAT44) allocateMappedPort() (net.PacketConn, netip.AddrPort) { + // Clean up old entries before trying to allocate, to free up any + // expired ports. + n.gc() + + ip := n.ExternalInterface.V4() + pc, err := n.Machine.ListenPacket(context.Background(), "udp", net.JoinHostPort(ip.String(), "0")) + if err != nil { + panic(fmt.Sprintf("ran out of NAT ports: %v", err)) + } + addr := netip.AddrPortFrom(ip, uint16(pc.LocalAddr().(*net.UDPAddr).Port)) + return pc, addr +} + +func (n *SNAT44) gc() { + now := n.timeNow() + for _, m := range n.byLAN { + if !now.After(m.deadline) { + continue + } + m.pc.Close() + delete(n.byLAN, n.Type.key(m.lanSrc, m.lanDst)) + delete(n.byWAN, m.wanSrc) + } +} diff --git a/tstest/tstest.go b/tstest/tstest.go index 118aa3827..2d0d1351e 100644 --- a/tstest/tstest.go +++ b/tstest/tstest.go @@ -1,95 +1,95 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package tstest provides utilities for use in unit tests. -package tstest - -import ( - "context" - "os" - "strconv" - "strings" - "sync/atomic" - "testing" - "time" - - "tailscale.com/envknob" - "tailscale.com/logtail/backoff" - "tailscale.com/types/logger" - "tailscale.com/util/cibuild" -) - -// Replace replaces the value of target with val. -// The old value is restored when the test ends. -func Replace[T any](t testing.TB, target *T, val T) { - t.Helper() - if target == nil { - t.Fatalf("Replace: nil pointer") - panic("unreachable") // pacify staticcheck - } - old := *target - t.Cleanup(func() { - *target = old - }) - - *target = val - return -} - -// WaitFor retries try for up to maxWait. -// It returns nil once try returns nil the first time. -// If maxWait passes without success, it returns try's last error. -func WaitFor(maxWait time.Duration, try func() error) error { - bo := backoff.NewBackoff("wait-for", logger.Discard, maxWait/4) - deadline := time.Now().Add(maxWait) - var err error - for time.Now().Before(deadline) { - err = try() - if err == nil { - break - } - bo.BackOff(context.Background(), err) - } - return err -} - -var testNum atomic.Int32 - -// Shard skips t if it's not running if the TS_TEST_SHARD test shard is set to -// "n/m" and this test execution number in the process mod m is not equal to n-1. -// That is, to run with 4 shards, set TS_TEST_SHARD=1/4, ..., TS_TEST_SHARD=4/4 -// for the four jobs. -func Shard(t testing.TB) { - e := os.Getenv("TS_TEST_SHARD") - a, b, ok := strings.Cut(e, "/") - if !ok { - return - } - wantShard, _ := strconv.ParseInt(a, 10, 32) - shards, _ := strconv.ParseInt(b, 10, 32) - if wantShard == 0 || shards == 0 { - return - } - - shard := ((testNum.Add(1) - 1) % int32(shards)) + 1 - if shard != int32(wantShard) { - t.Skipf("skipping shard %d/%d (process has TS_TEST_SHARD=%q)", shard, shards, e) - } -} - -// SkipOnUnshardedCI skips t if we're in CI and the TS_TEST_SHARD -// environment variable isn't set. -func SkipOnUnshardedCI(t testing.TB) { - if cibuild.On() && os.Getenv("TS_TEST_SHARD") == "" { - t.Skip("skipping on CI without TS_TEST_SHARD") - } -} - -var serializeParallel = envknob.RegisterBool("TS_SERIAL_TESTS") - -// Parallel calls t.Parallel, unless TS_SERIAL_TESTS is set true. -func Parallel(t *testing.T) { - if !serializeParallel() { - t.Parallel() - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package tstest provides utilities for use in unit tests. +package tstest + +import ( + "context" + "os" + "strconv" + "strings" + "sync/atomic" + "testing" + "time" + + "tailscale.com/envknob" + "tailscale.com/logtail/backoff" + "tailscale.com/types/logger" + "tailscale.com/util/cibuild" +) + +// Replace replaces the value of target with val. +// The old value is restored when the test ends. +func Replace[T any](t testing.TB, target *T, val T) { + t.Helper() + if target == nil { + t.Fatalf("Replace: nil pointer") + panic("unreachable") // pacify staticcheck + } + old := *target + t.Cleanup(func() { + *target = old + }) + + *target = val + return +} + +// WaitFor retries try for up to maxWait. +// It returns nil once try returns nil the first time. +// If maxWait passes without success, it returns try's last error. +func WaitFor(maxWait time.Duration, try func() error) error { + bo := backoff.NewBackoff("wait-for", logger.Discard, maxWait/4) + deadline := time.Now().Add(maxWait) + var err error + for time.Now().Before(deadline) { + err = try() + if err == nil { + break + } + bo.BackOff(context.Background(), err) + } + return err +} + +var testNum atomic.Int32 + +// Shard skips t if it's not running if the TS_TEST_SHARD test shard is set to +// "n/m" and this test execution number in the process mod m is not equal to n-1. +// That is, to run with 4 shards, set TS_TEST_SHARD=1/4, ..., TS_TEST_SHARD=4/4 +// for the four jobs. +func Shard(t testing.TB) { + e := os.Getenv("TS_TEST_SHARD") + a, b, ok := strings.Cut(e, "/") + if !ok { + return + } + wantShard, _ := strconv.ParseInt(a, 10, 32) + shards, _ := strconv.ParseInt(b, 10, 32) + if wantShard == 0 || shards == 0 { + return + } + + shard := ((testNum.Add(1) - 1) % int32(shards)) + 1 + if shard != int32(wantShard) { + t.Skipf("skipping shard %d/%d (process has TS_TEST_SHARD=%q)", shard, shards, e) + } +} + +// SkipOnUnshardedCI skips t if we're in CI and the TS_TEST_SHARD +// environment variable isn't set. +func SkipOnUnshardedCI(t testing.TB) { + if cibuild.On() && os.Getenv("TS_TEST_SHARD") == "" { + t.Skip("skipping on CI without TS_TEST_SHARD") + } +} + +var serializeParallel = envknob.RegisterBool("TS_SERIAL_TESTS") + +// Parallel calls t.Parallel, unless TS_SERIAL_TESTS is set true. +func Parallel(t *testing.T) { + if !serializeParallel() { + t.Parallel() + } +} diff --git a/tstest/tstest_test.go b/tstest/tstest_test.go index 20a9f7bf1..e988d5d56 100644 --- a/tstest/tstest_test.go +++ b/tstest/tstest_test.go @@ -1,24 +1,24 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstest - -import "testing" - -func TestReplace(t *testing.T) { - before := "before" - done := false - t.Run("replace", func(t *testing.T) { - Replace(t, &before, "after") - if before != "after" { - t.Errorf("before = %q; want %q", before, "after") - } - done = true - }) - if !done { - t.Fatal("subtest didn't run") - } - if before != "before" { - t.Errorf("before = %q; want %q", before, "before") - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstest + +import "testing" + +func TestReplace(t *testing.T) { + before := "before" + done := false + t.Run("replace", func(t *testing.T) { + Replace(t, &before, "after") + if before != "after" { + t.Errorf("before = %q; want %q", before, "after") + } + done = true + }) + if !done { + t.Fatal("subtest didn't run") + } + if before != "before" { + t.Errorf("before = %q; want %q", before, "before") + } +} diff --git a/tstime/mono/mono.go b/tstime/mono/mono.go index 94dca7d79..260e02b0f 100644 --- a/tstime/mono/mono.go +++ b/tstime/mono/mono.go @@ -1,127 +1,127 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package mono provides fast monotonic time. -// On most platforms, mono.Now is about 2x faster than time.Now. -// However, time.Now is really fast, and nicer to use. -// -// For almost all purposes, you should use time.Now. -// -// Package mono exists because we get the current time multiple -// times per network packet, at which point it makes a -// measurable difference. -package mono - -import ( - "fmt" - "sync/atomic" - "time" -) - -// Time is the number of nanoseconds elapsed since an unspecified reference start time. -type Time int64 - -// Now returns the current monotonic time. -func Now() Time { - // On a newly started machine, the monotonic clock might be very near zero. - // Thus mono.Time(0).Before(mono.Now.Add(-time.Minute)) might yield true. - // The corresponding package time expression never does, if the wall clock is correct. - // Preserve this correspondence by increasing the "base" monotonic clock by a fair amount. - const baseOffset int64 = 1 << 55 // approximately 10,000 hours in nanoseconds - return Time(int64(time.Since(baseWall)) + baseOffset) -} - -// Since returns the time elapsed since t. -func Since(t Time) time.Duration { - return time.Duration(Now() - t) -} - -// Sub returns t-n, the duration from n to t. -func (t Time) Sub(n Time) time.Duration { - return time.Duration(t - n) -} - -// Add returns t+d. -func (t Time) Add(d time.Duration) Time { - return t + Time(d) -} - -// After reports t > n, whether t is after n. -func (t Time) After(n Time) bool { - return t > n -} - -// Before reports t < n, whether t is before n. -func (t Time) Before(n Time) bool { - return t < n -} - -// IsZero reports whether t == 0. -func (t Time) IsZero() bool { - return t == 0 -} - -// StoreAtomic does an atomic store *t = new. -func (t *Time) StoreAtomic(new Time) { - atomic.StoreInt64((*int64)(t), int64(new)) -} - -// LoadAtomic does an atomic load *t. -func (t *Time) LoadAtomic() Time { - return Time(atomic.LoadInt64((*int64)(t))) -} - -// baseWall and baseMono are a pair of almost-identical times used to correlate a Time with a wall time. -var ( - baseWall time.Time - baseMono Time -) - -func init() { - baseWall = time.Now() - baseMono = Now() -} - -// String prints t, including an estimated equivalent wall clock. -// This is best-effort only, for rough debugging purposes only. -// Since t is a monotonic time, it can vary from the actual wall clock by arbitrary amounts. -// Even in the best of circumstances, it may vary by a few milliseconds. -func (t Time) String() string { - return fmt.Sprintf("mono.Time(ns=%d, estimated wall=%v)", int64(t), baseWall.Add(t.Sub(baseMono)).Truncate(0)) -} - -// WallTime returns an approximate wall time that corresponded to t. -func (t Time) WallTime() time.Time { - if !t.IsZero() { - return baseWall.Add(t.Sub(baseMono)).Truncate(0) - } - return time.Time{} -} - -// MarshalJSON formats t for JSON as if it were a time.Time. -// We format Time this way for backwards-compatibility. -// Time does not survive a MarshalJSON/UnmarshalJSON round trip unchanged -// across different invocations of the Go process. This is best-effort only. -// Since t is a monotonic time, it can vary from the actual wall clock by arbitrary amounts. -// Even in the best of circumstances, it may vary by a few milliseconds. -func (t Time) MarshalJSON() ([]byte, error) { - tt := t.WallTime() - return tt.MarshalJSON() -} - -// UnmarshalJSON sets t according to data. -// Time does not survive a MarshalJSON/UnmarshalJSON round trip unchanged -// across different invocations of the Go process. This is best-effort only. -func (t *Time) UnmarshalJSON(data []byte) error { - var tt time.Time - err := tt.UnmarshalJSON(data) - if err != nil { - return err - } - if tt.IsZero() { - *t = 0 - return nil - } - *t = baseMono.Add(tt.Sub(baseWall)) - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package mono provides fast monotonic time. +// On most platforms, mono.Now is about 2x faster than time.Now. +// However, time.Now is really fast, and nicer to use. +// +// For almost all purposes, you should use time.Now. +// +// Package mono exists because we get the current time multiple +// times per network packet, at which point it makes a +// measurable difference. +package mono + +import ( + "fmt" + "sync/atomic" + "time" +) + +// Time is the number of nanoseconds elapsed since an unspecified reference start time. +type Time int64 + +// Now returns the current monotonic time. +func Now() Time { + // On a newly started machine, the monotonic clock might be very near zero. + // Thus mono.Time(0).Before(mono.Now.Add(-time.Minute)) might yield true. + // The corresponding package time expression never does, if the wall clock is correct. + // Preserve this correspondence by increasing the "base" monotonic clock by a fair amount. + const baseOffset int64 = 1 << 55 // approximately 10,000 hours in nanoseconds + return Time(int64(time.Since(baseWall)) + baseOffset) +} + +// Since returns the time elapsed since t. +func Since(t Time) time.Duration { + return time.Duration(Now() - t) +} + +// Sub returns t-n, the duration from n to t. +func (t Time) Sub(n Time) time.Duration { + return time.Duration(t - n) +} + +// Add returns t+d. +func (t Time) Add(d time.Duration) Time { + return t + Time(d) +} + +// After reports t > n, whether t is after n. +func (t Time) After(n Time) bool { + return t > n +} + +// Before reports t < n, whether t is before n. +func (t Time) Before(n Time) bool { + return t < n +} + +// IsZero reports whether t == 0. +func (t Time) IsZero() bool { + return t == 0 +} + +// StoreAtomic does an atomic store *t = new. +func (t *Time) StoreAtomic(new Time) { + atomic.StoreInt64((*int64)(t), int64(new)) +} + +// LoadAtomic does an atomic load *t. +func (t *Time) LoadAtomic() Time { + return Time(atomic.LoadInt64((*int64)(t))) +} + +// baseWall and baseMono are a pair of almost-identical times used to correlate a Time with a wall time. +var ( + baseWall time.Time + baseMono Time +) + +func init() { + baseWall = time.Now() + baseMono = Now() +} + +// String prints t, including an estimated equivalent wall clock. +// This is best-effort only, for rough debugging purposes only. +// Since t is a monotonic time, it can vary from the actual wall clock by arbitrary amounts. +// Even in the best of circumstances, it may vary by a few milliseconds. +func (t Time) String() string { + return fmt.Sprintf("mono.Time(ns=%d, estimated wall=%v)", int64(t), baseWall.Add(t.Sub(baseMono)).Truncate(0)) +} + +// WallTime returns an approximate wall time that corresponded to t. +func (t Time) WallTime() time.Time { + if !t.IsZero() { + return baseWall.Add(t.Sub(baseMono)).Truncate(0) + } + return time.Time{} +} + +// MarshalJSON formats t for JSON as if it were a time.Time. +// We format Time this way for backwards-compatibility. +// Time does not survive a MarshalJSON/UnmarshalJSON round trip unchanged +// across different invocations of the Go process. This is best-effort only. +// Since t is a monotonic time, it can vary from the actual wall clock by arbitrary amounts. +// Even in the best of circumstances, it may vary by a few milliseconds. +func (t Time) MarshalJSON() ([]byte, error) { + tt := t.WallTime() + return tt.MarshalJSON() +} + +// UnmarshalJSON sets t according to data. +// Time does not survive a MarshalJSON/UnmarshalJSON round trip unchanged +// across different invocations of the Go process. This is best-effort only. +func (t *Time) UnmarshalJSON(data []byte) error { + var tt time.Time + err := tt.UnmarshalJSON(data) + if err != nil { + return err + } + if tt.IsZero() { + *t = 0 + return nil + } + *t = baseMono.Add(tt.Sub(baseWall)) + return nil +} diff --git a/tstime/rate/rate.go b/tstime/rate/rate.go index 19dc26e6a..f0473862a 100644 --- a/tstime/rate/rate.go +++ b/tstime/rate/rate.go @@ -1,90 +1,90 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// This is a modified, simplified version of code from golang.org/x/time/rate. - -// Copyright 2015 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package rate provides a rate limiter. -package rate - -import ( - "sync" - "time" - - "tailscale.com/tstime/mono" -) - -// Limit defines the maximum frequency of some events. -// Limit is represented as number of events per second. -// A zero Limit is invalid. -type Limit float64 - -// Every converts a minimum time interval between events to a Limit. -func Every(interval time.Duration) Limit { - if interval <= 0 { - panic("invalid interval") - } - return 1 / Limit(interval.Seconds()) -} - -// A Limiter controls how frequently events are allowed to happen. -// It implements a [token bucket] of a particular size b, -// initially full and refilled at rate r tokens per second. -// Informally, in any large enough time interval, -// the Limiter limits the rate to r tokens per second, -// with a maximum burst size of b events. -// Use NewLimiter to create non-zero Limiters. -// -// [token bucket]: https://en.wikipedia.org/wiki/Token_bucket -type Limiter struct { - limit Limit - burst float64 - mu sync.Mutex // protects following fields - tokens float64 // number of tokens currently in bucket - last mono.Time // the last time the limiter's tokens field was updated -} - -// NewLimiter returns a new Limiter that allows events up to rate r and permits -// bursts of at most b tokens. -func NewLimiter(r Limit, b int) *Limiter { - if b < 1 { - panic("bad burst, must be at least 1") - } - return &Limiter{limit: r, burst: float64(b)} -} - -// Allow reports whether an event may happen now. -func (lim *Limiter) Allow() bool { - return lim.allow(mono.Now()) -} - -func (lim *Limiter) allow(now mono.Time) bool { - lim.mu.Lock() - defer lim.mu.Unlock() - - // If time has moved backwards, look around awkwardly and pretend nothing happened. - if now.Before(lim.last) { - lim.last = now - } - - // Calculate the new number of tokens available due to the passage of time. - elapsed := now.Sub(lim.last) - tokens := lim.tokens + float64(lim.limit)*elapsed.Seconds() - if tokens > lim.burst { - tokens = lim.burst - } - - // Consume a token. - tokens-- - - // Update state. - ok := tokens >= 0 - if ok { - lim.last = now - lim.tokens = tokens - } - return ok -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// This is a modified, simplified version of code from golang.org/x/time/rate. + +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package rate provides a rate limiter. +package rate + +import ( + "sync" + "time" + + "tailscale.com/tstime/mono" +) + +// Limit defines the maximum frequency of some events. +// Limit is represented as number of events per second. +// A zero Limit is invalid. +type Limit float64 + +// Every converts a minimum time interval between events to a Limit. +func Every(interval time.Duration) Limit { + if interval <= 0 { + panic("invalid interval") + } + return 1 / Limit(interval.Seconds()) +} + +// A Limiter controls how frequently events are allowed to happen. +// It implements a [token bucket] of a particular size b, +// initially full and refilled at rate r tokens per second. +// Informally, in any large enough time interval, +// the Limiter limits the rate to r tokens per second, +// with a maximum burst size of b events. +// Use NewLimiter to create non-zero Limiters. +// +// [token bucket]: https://en.wikipedia.org/wiki/Token_bucket +type Limiter struct { + limit Limit + burst float64 + mu sync.Mutex // protects following fields + tokens float64 // number of tokens currently in bucket + last mono.Time // the last time the limiter's tokens field was updated +} + +// NewLimiter returns a new Limiter that allows events up to rate r and permits +// bursts of at most b tokens. +func NewLimiter(r Limit, b int) *Limiter { + if b < 1 { + panic("bad burst, must be at least 1") + } + return &Limiter{limit: r, burst: float64(b)} +} + +// Allow reports whether an event may happen now. +func (lim *Limiter) Allow() bool { + return lim.allow(mono.Now()) +} + +func (lim *Limiter) allow(now mono.Time) bool { + lim.mu.Lock() + defer lim.mu.Unlock() + + // If time has moved backwards, look around awkwardly and pretend nothing happened. + if now.Before(lim.last) { + lim.last = now + } + + // Calculate the new number of tokens available due to the passage of time. + elapsed := now.Sub(lim.last) + tokens := lim.tokens + float64(lim.limit)*elapsed.Seconds() + if tokens > lim.burst { + tokens = lim.burst + } + + // Consume a token. + tokens-- + + // Update state. + ok := tokens >= 0 + if ok { + lim.last = now + lim.tokens = tokens + } + return ok +} diff --git a/tstime/tstime.go b/tstime/tstime.go index 22616bca7..1c006355f 100644 --- a/tstime/tstime.go +++ b/tstime/tstime.go @@ -1,185 +1,185 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package tstime defines Tailscale-specific time utilities. -package tstime - -import ( - "context" - "strconv" - "strings" - "time" -) - -// Parse3339 is a wrapper around time.Parse(time.RFC3339, s). -func Parse3339(s string) (time.Time, error) { - return time.Parse(time.RFC3339, s) -} - -// Parse3339B is Parse3339 but for byte slices. -func Parse3339B(b []byte) (time.Time, error) { - var t time.Time - if err := t.UnmarshalText(b); err != nil { - return Parse3339(string(b)) // reproduce same error message - } - return t, nil -} - -// ParseDuration is more expressive than [time.ParseDuration], -// also accepting 'd' (days) and 'w' (weeks) literals. -func ParseDuration(s string) (time.Duration, error) { - for { - end := strings.IndexAny(s, "dw") - if end < 0 { - break - } - start := end - (len(s[:end]) - len(strings.TrimRight(s[:end], "0123456789"))) - n, err := strconv.Atoi(s[start:end]) - if err != nil { - return 0, err - } - hours := 24 - if s[end] == 'w' { - hours *= 7 - } - s = s[:start] + s[end+1:] + strconv.Itoa(n*hours) + "h" - } - return time.ParseDuration(s) -} - -// Sleep is like [time.Sleep] but returns early upon context cancelation. -// It reports whether the full sleep duration was achieved. -func Sleep(ctx context.Context, d time.Duration) bool { - timer := time.NewTimer(d) - defer timer.Stop() - select { - case <-ctx.Done(): - return false - case <-timer.C: - return true - } -} - -// DefaultClock is a wrapper around a Clock. -// It uses StdClock by default if Clock is nil. -type DefaultClock struct{ Clock } - -// TODO: We should make the methods of DefaultClock inlineable -// so that we can optimize for the common case where c.Clock == nil. - -func (c DefaultClock) Now() time.Time { - if c.Clock == nil { - return time.Now() - } - return c.Clock.Now() -} -func (c DefaultClock) NewTimer(d time.Duration) (TimerController, <-chan time.Time) { - if c.Clock == nil { - t := time.NewTimer(d) - return t, t.C - } - return c.Clock.NewTimer(d) -} -func (c DefaultClock) NewTicker(d time.Duration) (TickerController, <-chan time.Time) { - if c.Clock == nil { - t := time.NewTicker(d) - return t, t.C - } - return c.Clock.NewTicker(d) -} -func (c DefaultClock) AfterFunc(d time.Duration, f func()) TimerController { - if c.Clock == nil { - return time.AfterFunc(d, f) - } - return c.Clock.AfterFunc(d, f) -} -func (c DefaultClock) Since(t time.Time) time.Duration { - if c.Clock == nil { - return time.Since(t) - } - return c.Clock.Since(t) -} - -// Clock offers a subset of the functionality from the std/time package. -// Normally, applications will use the StdClock implementation that calls the -// appropriate std/time exported funcs. The advantage of using Clock is that -// tests can substitute a different implementation, allowing the test to control -// time precisely, something required for certain types of tests to be possible -// at all, speeds up execution by not needing to sleep, and can dramatically -// reduce the risk of flakes due to tests executing too slowly or quickly. -type Clock interface { - // Now returns the current time, as in time.Now. - Now() time.Time - // NewTimer returns a timer whose notion of the current time is controlled - // by this Clock. It follows the semantics of time.NewTimer as closely as - // possible but is adapted to return an interface, so the channel needs to - // be returned as well. - NewTimer(d time.Duration) (TimerController, <-chan time.Time) - // NewTicker returns a ticker whose notion of the current time is controlled - // by this Clock. It follows the semantics of time.NewTicker as closely as - // possible but is adapted to return an interface, so the channel needs to - // be returned as well. - NewTicker(d time.Duration) (TickerController, <-chan time.Time) - // AfterFunc returns a ticker whose notion of the current time is controlled - // by this Clock. When the ticker expires, it will call the provided func. - // It follows the semantics of time.AfterFunc. - AfterFunc(d time.Duration, f func()) TimerController - // Since returns the time elapsed since t. - // It follows the semantics of time.Since. - Since(t time.Time) time.Duration -} - -// TickerController offers the receivers of a time.Ticker to ensure -// compatibility with standard timers, but allows for the option of substituting -// a standard timer with something else for testing purposes. -type TickerController interface { - // Reset follows the same semantics as with time.Ticker.Reset. - Reset(d time.Duration) - // Stop follows the same semantics as with time.Ticker.Stop. - Stop() -} - -// TimerController offers the receivers of a time.Timer to ensure -// compatibility with standard timers, but allows for the option of substituting -// a standard timer with something else for testing purposes. -type TimerController interface { - // Reset follows the same semantics as with time.Timer.Reset. - Reset(d time.Duration) bool - // Stop follows the same semantics as with time.Timer.Stop. - Stop() bool -} - -// StdClock is a simple implementation of Clock using the relevant funcs in the -// std/time package. -type StdClock struct{} - -// Now calls time.Now. -func (StdClock) Now() time.Time { - return time.Now() -} - -// NewTimer calls time.NewTimer. As an interface does not allow for struct -// members and other packages cannot add receivers to another package, the -// channel is also returned because it would be otherwise inaccessible. -func (StdClock) NewTimer(d time.Duration) (TimerController, <-chan time.Time) { - t := time.NewTimer(d) - return t, t.C -} - -// NewTicker calls time.NewTicker. As an interface does not allow for struct -// members and other packages cannot add receivers to another package, the -// channel is also returned because it would be otherwise inaccessible. -func (StdClock) NewTicker(d time.Duration) (TickerController, <-chan time.Time) { - t := time.NewTicker(d) - return t, t.C -} - -// AfterFunc calls time.AfterFunc. -func (StdClock) AfterFunc(d time.Duration, f func()) TimerController { - return time.AfterFunc(d, f) -} - -// Since calls time.Since. -func (StdClock) Since(t time.Time) time.Duration { - return time.Since(t) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package tstime defines Tailscale-specific time utilities. +package tstime + +import ( + "context" + "strconv" + "strings" + "time" +) + +// Parse3339 is a wrapper around time.Parse(time.RFC3339, s). +func Parse3339(s string) (time.Time, error) { + return time.Parse(time.RFC3339, s) +} + +// Parse3339B is Parse3339 but for byte slices. +func Parse3339B(b []byte) (time.Time, error) { + var t time.Time + if err := t.UnmarshalText(b); err != nil { + return Parse3339(string(b)) // reproduce same error message + } + return t, nil +} + +// ParseDuration is more expressive than [time.ParseDuration], +// also accepting 'd' (days) and 'w' (weeks) literals. +func ParseDuration(s string) (time.Duration, error) { + for { + end := strings.IndexAny(s, "dw") + if end < 0 { + break + } + start := end - (len(s[:end]) - len(strings.TrimRight(s[:end], "0123456789"))) + n, err := strconv.Atoi(s[start:end]) + if err != nil { + return 0, err + } + hours := 24 + if s[end] == 'w' { + hours *= 7 + } + s = s[:start] + s[end+1:] + strconv.Itoa(n*hours) + "h" + } + return time.ParseDuration(s) +} + +// Sleep is like [time.Sleep] but returns early upon context cancelation. +// It reports whether the full sleep duration was achieved. +func Sleep(ctx context.Context, d time.Duration) bool { + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-ctx.Done(): + return false + case <-timer.C: + return true + } +} + +// DefaultClock is a wrapper around a Clock. +// It uses StdClock by default if Clock is nil. +type DefaultClock struct{ Clock } + +// TODO: We should make the methods of DefaultClock inlineable +// so that we can optimize for the common case where c.Clock == nil. + +func (c DefaultClock) Now() time.Time { + if c.Clock == nil { + return time.Now() + } + return c.Clock.Now() +} +func (c DefaultClock) NewTimer(d time.Duration) (TimerController, <-chan time.Time) { + if c.Clock == nil { + t := time.NewTimer(d) + return t, t.C + } + return c.Clock.NewTimer(d) +} +func (c DefaultClock) NewTicker(d time.Duration) (TickerController, <-chan time.Time) { + if c.Clock == nil { + t := time.NewTicker(d) + return t, t.C + } + return c.Clock.NewTicker(d) +} +func (c DefaultClock) AfterFunc(d time.Duration, f func()) TimerController { + if c.Clock == nil { + return time.AfterFunc(d, f) + } + return c.Clock.AfterFunc(d, f) +} +func (c DefaultClock) Since(t time.Time) time.Duration { + if c.Clock == nil { + return time.Since(t) + } + return c.Clock.Since(t) +} + +// Clock offers a subset of the functionality from the std/time package. +// Normally, applications will use the StdClock implementation that calls the +// appropriate std/time exported funcs. The advantage of using Clock is that +// tests can substitute a different implementation, allowing the test to control +// time precisely, something required for certain types of tests to be possible +// at all, speeds up execution by not needing to sleep, and can dramatically +// reduce the risk of flakes due to tests executing too slowly or quickly. +type Clock interface { + // Now returns the current time, as in time.Now. + Now() time.Time + // NewTimer returns a timer whose notion of the current time is controlled + // by this Clock. It follows the semantics of time.NewTimer as closely as + // possible but is adapted to return an interface, so the channel needs to + // be returned as well. + NewTimer(d time.Duration) (TimerController, <-chan time.Time) + // NewTicker returns a ticker whose notion of the current time is controlled + // by this Clock. It follows the semantics of time.NewTicker as closely as + // possible but is adapted to return an interface, so the channel needs to + // be returned as well. + NewTicker(d time.Duration) (TickerController, <-chan time.Time) + // AfterFunc returns a ticker whose notion of the current time is controlled + // by this Clock. When the ticker expires, it will call the provided func. + // It follows the semantics of time.AfterFunc. + AfterFunc(d time.Duration, f func()) TimerController + // Since returns the time elapsed since t. + // It follows the semantics of time.Since. + Since(t time.Time) time.Duration +} + +// TickerController offers the receivers of a time.Ticker to ensure +// compatibility with standard timers, but allows for the option of substituting +// a standard timer with something else for testing purposes. +type TickerController interface { + // Reset follows the same semantics as with time.Ticker.Reset. + Reset(d time.Duration) + // Stop follows the same semantics as with time.Ticker.Stop. + Stop() +} + +// TimerController offers the receivers of a time.Timer to ensure +// compatibility with standard timers, but allows for the option of substituting +// a standard timer with something else for testing purposes. +type TimerController interface { + // Reset follows the same semantics as with time.Timer.Reset. + Reset(d time.Duration) bool + // Stop follows the same semantics as with time.Timer.Stop. + Stop() bool +} + +// StdClock is a simple implementation of Clock using the relevant funcs in the +// std/time package. +type StdClock struct{} + +// Now calls time.Now. +func (StdClock) Now() time.Time { + return time.Now() +} + +// NewTimer calls time.NewTimer. As an interface does not allow for struct +// members and other packages cannot add receivers to another package, the +// channel is also returned because it would be otherwise inaccessible. +func (StdClock) NewTimer(d time.Duration) (TimerController, <-chan time.Time) { + t := time.NewTimer(d) + return t, t.C +} + +// NewTicker calls time.NewTicker. As an interface does not allow for struct +// members and other packages cannot add receivers to another package, the +// channel is also returned because it would be otherwise inaccessible. +func (StdClock) NewTicker(d time.Duration) (TickerController, <-chan time.Time) { + t := time.NewTicker(d) + return t, t.C +} + +// AfterFunc calls time.AfterFunc. +func (StdClock) AfterFunc(d time.Duration, f func()) TimerController { + return time.AfterFunc(d, f) +} + +// Since calls time.Since. +func (StdClock) Since(t time.Time) time.Duration { + return time.Since(t) +} diff --git a/tstime/tstime_test.go b/tstime/tstime_test.go index 1169408b6..3ffeaf0ff 100644 --- a/tstime/tstime_test.go +++ b/tstime/tstime_test.go @@ -1,36 +1,36 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tstime - -import ( - "testing" - "time" -) - -func TestParseDuration(t *testing.T) { - tests := []struct { - in string - want time.Duration - }{ - {"1h", time.Hour}, - {"1d", 24 * time.Hour}, - {"365d", 365 * 24 * time.Hour}, - {"12345d", 12345 * 24 * time.Hour}, - {"67890d", 67890 * 24 * time.Hour}, - {"100d", 100 * 24 * time.Hour}, - {"1d1d", 48 * time.Hour}, - {"1h1d", 25 * time.Hour}, - {"1d1h", 25 * time.Hour}, - {"1w", 7 * 24 * time.Hour}, - {"1w1d1h", 8*24*time.Hour + time.Hour}, - {"1w1d1h", 8*24*time.Hour + time.Hour}, - {"1y", 0}, - {"", 0}, - } - for _, tt := range tests { - if got, _ := ParseDuration(tt.in); got != tt.want { - t.Errorf("ParseDuration(%q) = %d; want %d", tt.in, got, tt.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstime + +import ( + "testing" + "time" +) + +func TestParseDuration(t *testing.T) { + tests := []struct { + in string + want time.Duration + }{ + {"1h", time.Hour}, + {"1d", 24 * time.Hour}, + {"365d", 365 * 24 * time.Hour}, + {"12345d", 12345 * 24 * time.Hour}, + {"67890d", 67890 * 24 * time.Hour}, + {"100d", 100 * 24 * time.Hour}, + {"1d1d", 48 * time.Hour}, + {"1h1d", 25 * time.Hour}, + {"1d1h", 25 * time.Hour}, + {"1w", 7 * 24 * time.Hour}, + {"1w1d1h", 8*24*time.Hour + time.Hour}, + {"1w1d1h", 8*24*time.Hour + time.Hour}, + {"1y", 0}, + {"", 0}, + } + for _, tt := range tests { + if got, _ := ParseDuration(tt.in); got != tt.want { + t.Errorf("ParseDuration(%q) = %d; want %d", tt.in, got, tt.want) + } + } +} diff --git a/tsweb/debug_test.go b/tsweb/debug_test.go index 504ec06ba..2a68ab6fb 100644 --- a/tsweb/debug_test.go +++ b/tsweb/debug_test.go @@ -1,208 +1,208 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tsweb - -import ( - "fmt" - "io" - "net/http" - "net/http/httptest" - "runtime" - "strings" - "testing" -) - -func TestDebugger(t *testing.T) { - mux := http.NewServeMux() - - dbg1 := Debugger(mux) - if dbg1 == nil { - t.Fatal("didn't get a debugger from mux") - } - - dbg2 := Debugger(mux) - if dbg2 != dbg1 { - t.Fatal("Debugger returned different debuggers for the same mux") - } - - t.Run("cpu_pprof", func(t *testing.T) { - if testing.Short() { - t.Skip("skipping second long test") - } - switch runtime.GOOS { - case "linux", "darwin": - default: - t.Skipf("skipping test on %v", runtime.GOOS) - } - req := httptest.NewRequest("GET", "/debug/pprof/profile?seconds=1", nil) - req.RemoteAddr = "100.101.102.103:1234" - rec := httptest.NewRecorder() - mux.ServeHTTP(rec, req) - res := rec.Result() - if res.StatusCode != 200 { - t.Errorf("unexpected %v", res.Status) - } - }) -} - -func get(m http.Handler, path, srcIP string) (int, string) { - req := httptest.NewRequest("GET", path, nil) - req.RemoteAddr = srcIP + ":1234" - rec := httptest.NewRecorder() - m.ServeHTTP(rec, req) - return rec.Result().StatusCode, rec.Body.String() -} - -const ( - tsIP = "100.100.100.100" - pubIP = "8.8.8.8" -) - -func TestDebuggerKV(t *testing.T) { - mux := http.NewServeMux() - dbg := Debugger(mux) - dbg.KV("Donuts", 42) - dbg.KV("Secret code", "hunter2") - val := "red" - dbg.KVFunc("Condition", func() any { return val }) - - code, _ := get(mux, "/debug/", pubIP) - if code != 403 { - t.Fatalf("debug access wasn't denied, got %v", code) - } - - code, body := get(mux, "/debug/", tsIP) - if code != 200 { - t.Fatalf("debug access failed, got %v", code) - } - for _, want := range []string{"Donuts", "42", "Secret code", "hunter2", "Condition", "red"} { - if !strings.Contains(body, want) { - t.Errorf("want %q in output, not found", want) - } - } - - val = "green" - code, body = get(mux, "/debug/", tsIP) - if code != 200 { - t.Fatalf("debug access failed, got %v", code) - } - for _, want := range []string{"Condition", "green"} { - if !strings.Contains(body, want) { - t.Errorf("want %q in output, not found", want) - } - } -} - -func TestDebuggerURL(t *testing.T) { - mux := http.NewServeMux() - dbg := Debugger(mux) - dbg.URL("https://www.tailscale.com", "Homepage") - - code, body := get(mux, "/debug/", tsIP) - if code != 200 { - t.Fatalf("debug access failed, got %v", code) - } - for _, want := range []string{"https://www.tailscale.com", "Homepage"} { - if !strings.Contains(body, want) { - t.Errorf("want %q in output, not found", want) - } - } -} - -func TestDebuggerSection(t *testing.T) { - mux := http.NewServeMux() - dbg := Debugger(mux) - dbg.Section(func(w io.Writer, r *http.Request) { - fmt.Fprintf(w, "Test output %v", r.RemoteAddr) - }) - - code, body := get(mux, "/debug/", tsIP) - if code != 200 { - t.Fatalf("debug access failed, got %v", code) - } - want := `Test output 100.100.100.100:1234` - if !strings.Contains(body, want) { - t.Errorf("want %q in output, not found", want) - } -} - -func TestDebuggerHandle(t *testing.T) { - mux := http.NewServeMux() - dbg := Debugger(mux) - dbg.Handle("check", "Consistency check", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "Test output %v", r.RemoteAddr) - })) - - code, body := get(mux, "/debug/", tsIP) - if code != 200 { - t.Fatalf("debug access failed, got %v", code) - } - for _, want := range []string{"/debug/check", "Consistency check"} { - if !strings.Contains(body, want) { - t.Errorf("want %q in output, not found", want) - } - } - - code, _ = get(mux, "/debug/check", pubIP) - if code != 403 { - t.Fatal("/debug/check should be protected, but isn't") - } - - code, body = get(mux, "/debug/check", tsIP) - if code != 200 { - t.Fatal("/debug/check denied debug access") - } - want := "Test output " + tsIP - if !strings.Contains(body, want) { - t.Errorf("want %q in output, not found", want) - } -} - -func ExampleDebugHandler_Handle() { - mux := http.NewServeMux() - dbg := Debugger(mux) - // Registers /debug/flushcache with the given handler, and adds a - // link to /debug/ with the description "Flush caches". - dbg.Handle("flushcache", "Flush caches", http.HandlerFunc(http.NotFound)) -} - -func ExampleDebugHandler_KV() { - mux := http.NewServeMux() - dbg := Debugger(mux) - // Adds two list items to /debug/, showing that the condition is - // red and there are 42 donuts. - dbg.KV("Condition", "red") - dbg.KV("Donuts", 42) -} - -func ExampleDebugHandler_KVFunc() { - mux := http.NewServeMux() - dbg := Debugger(mux) - // Adds an count of page renders to /debug/. Note this example - // isn't concurrency-safe. - views := 0 - dbg.KVFunc("Debug pageviews", func() any { - views = views + 1 - return views - }) - dbg.KV("Donuts", 42) -} - -func ExampleDebugHandler_URL() { - mux := http.NewServeMux() - dbg := Debugger(mux) - // Links to the Tailscale website from /debug/. - dbg.URL("https://www.tailscale.com", "Homepage") -} - -func ExampleDebugHandler_Section() { - mux := http.NewServeMux() - dbg := Debugger(mux) - // Adds a section to /debug/ that dumps the HTTP request of the - // visitor. - dbg.Section(func(w io.Writer, r *http.Request) { - io.WriteString(w, "

Dump of your HTTP request

") - fmt.Fprintf(w, "%#v", r) - }) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tsweb + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "runtime" + "strings" + "testing" +) + +func TestDebugger(t *testing.T) { + mux := http.NewServeMux() + + dbg1 := Debugger(mux) + if dbg1 == nil { + t.Fatal("didn't get a debugger from mux") + } + + dbg2 := Debugger(mux) + if dbg2 != dbg1 { + t.Fatal("Debugger returned different debuggers for the same mux") + } + + t.Run("cpu_pprof", func(t *testing.T) { + if testing.Short() { + t.Skip("skipping second long test") + } + switch runtime.GOOS { + case "linux", "darwin": + default: + t.Skipf("skipping test on %v", runtime.GOOS) + } + req := httptest.NewRequest("GET", "/debug/pprof/profile?seconds=1", nil) + req.RemoteAddr = "100.101.102.103:1234" + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + res := rec.Result() + if res.StatusCode != 200 { + t.Errorf("unexpected %v", res.Status) + } + }) +} + +func get(m http.Handler, path, srcIP string) (int, string) { + req := httptest.NewRequest("GET", path, nil) + req.RemoteAddr = srcIP + ":1234" + rec := httptest.NewRecorder() + m.ServeHTTP(rec, req) + return rec.Result().StatusCode, rec.Body.String() +} + +const ( + tsIP = "100.100.100.100" + pubIP = "8.8.8.8" +) + +func TestDebuggerKV(t *testing.T) { + mux := http.NewServeMux() + dbg := Debugger(mux) + dbg.KV("Donuts", 42) + dbg.KV("Secret code", "hunter2") + val := "red" + dbg.KVFunc("Condition", func() any { return val }) + + code, _ := get(mux, "/debug/", pubIP) + if code != 403 { + t.Fatalf("debug access wasn't denied, got %v", code) + } + + code, body := get(mux, "/debug/", tsIP) + if code != 200 { + t.Fatalf("debug access failed, got %v", code) + } + for _, want := range []string{"Donuts", "42", "Secret code", "hunter2", "Condition", "red"} { + if !strings.Contains(body, want) { + t.Errorf("want %q in output, not found", want) + } + } + + val = "green" + code, body = get(mux, "/debug/", tsIP) + if code != 200 { + t.Fatalf("debug access failed, got %v", code) + } + for _, want := range []string{"Condition", "green"} { + if !strings.Contains(body, want) { + t.Errorf("want %q in output, not found", want) + } + } +} + +func TestDebuggerURL(t *testing.T) { + mux := http.NewServeMux() + dbg := Debugger(mux) + dbg.URL("https://www.tailscale.com", "Homepage") + + code, body := get(mux, "/debug/", tsIP) + if code != 200 { + t.Fatalf("debug access failed, got %v", code) + } + for _, want := range []string{"https://www.tailscale.com", "Homepage"} { + if !strings.Contains(body, want) { + t.Errorf("want %q in output, not found", want) + } + } +} + +func TestDebuggerSection(t *testing.T) { + mux := http.NewServeMux() + dbg := Debugger(mux) + dbg.Section(func(w io.Writer, r *http.Request) { + fmt.Fprintf(w, "Test output %v", r.RemoteAddr) + }) + + code, body := get(mux, "/debug/", tsIP) + if code != 200 { + t.Fatalf("debug access failed, got %v", code) + } + want := `Test output 100.100.100.100:1234` + if !strings.Contains(body, want) { + t.Errorf("want %q in output, not found", want) + } +} + +func TestDebuggerHandle(t *testing.T) { + mux := http.NewServeMux() + dbg := Debugger(mux) + dbg.Handle("check", "Consistency check", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "Test output %v", r.RemoteAddr) + })) + + code, body := get(mux, "/debug/", tsIP) + if code != 200 { + t.Fatalf("debug access failed, got %v", code) + } + for _, want := range []string{"/debug/check", "Consistency check"} { + if !strings.Contains(body, want) { + t.Errorf("want %q in output, not found", want) + } + } + + code, _ = get(mux, "/debug/check", pubIP) + if code != 403 { + t.Fatal("/debug/check should be protected, but isn't") + } + + code, body = get(mux, "/debug/check", tsIP) + if code != 200 { + t.Fatal("/debug/check denied debug access") + } + want := "Test output " + tsIP + if !strings.Contains(body, want) { + t.Errorf("want %q in output, not found", want) + } +} + +func ExampleDebugHandler_Handle() { + mux := http.NewServeMux() + dbg := Debugger(mux) + // Registers /debug/flushcache with the given handler, and adds a + // link to /debug/ with the description "Flush caches". + dbg.Handle("flushcache", "Flush caches", http.HandlerFunc(http.NotFound)) +} + +func ExampleDebugHandler_KV() { + mux := http.NewServeMux() + dbg := Debugger(mux) + // Adds two list items to /debug/, showing that the condition is + // red and there are 42 donuts. + dbg.KV("Condition", "red") + dbg.KV("Donuts", 42) +} + +func ExampleDebugHandler_KVFunc() { + mux := http.NewServeMux() + dbg := Debugger(mux) + // Adds an count of page renders to /debug/. Note this example + // isn't concurrency-safe. + views := 0 + dbg.KVFunc("Debug pageviews", func() any { + views = views + 1 + return views + }) + dbg.KV("Donuts", 42) +} + +func ExampleDebugHandler_URL() { + mux := http.NewServeMux() + dbg := Debugger(mux) + // Links to the Tailscale website from /debug/. + dbg.URL("https://www.tailscale.com", "Homepage") +} + +func ExampleDebugHandler_Section() { + mux := http.NewServeMux() + dbg := Debugger(mux) + // Adds a section to /debug/ that dumps the HTTP request of the + // visitor. + dbg.Section(func(w io.Writer, r *http.Request) { + io.WriteString(w, "

Dump of your HTTP request

") + fmt.Fprintf(w, "%#v", r) + }) +} diff --git a/tsweb/promvarz/promvarz_test.go b/tsweb/promvarz/promvarz_test.go index 7f9b3396e..a3f4e66f1 100644 --- a/tsweb/promvarz/promvarz_test.go +++ b/tsweb/promvarz/promvarz_test.go @@ -1,38 +1,38 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause -package promvarz - -import ( - "expvar" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promauto" - "github.com/prometheus/client_golang/prometheus/testutil" -) - -var ( - testVar1 = expvar.NewInt("gauge_promvarz_test_expvar") - testVar2 = promauto.NewGauge(prometheus.GaugeOpts{Name: "promvarz_test_native"}) -) - -func TestHandler(t *testing.T) { - testVar1.Set(42) - testVar2.Set(4242) - - svr := httptest.NewServer(http.HandlerFunc(Handler)) - defer svr.Close() - - want := ` - # TYPE promvarz_test_expvar gauge - promvarz_test_expvar 42 - # TYPE promvarz_test_native gauge - promvarz_test_native 4242 - ` - if err := testutil.ScrapeAndCompare(svr.URL, strings.NewReader(want), "promvarz_test_expvar", "promvarz_test_native"); err != nil { - t.Error(err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause +package promvarz + +import ( + "expvar" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "github.com/prometheus/client_golang/prometheus/testutil" +) + +var ( + testVar1 = expvar.NewInt("gauge_promvarz_test_expvar") + testVar2 = promauto.NewGauge(prometheus.GaugeOpts{Name: "promvarz_test_native"}) +) + +func TestHandler(t *testing.T) { + testVar1.Set(42) + testVar2.Set(4242) + + svr := httptest.NewServer(http.HandlerFunc(Handler)) + defer svr.Close() + + want := ` + # TYPE promvarz_test_expvar gauge + promvarz_test_expvar 42 + # TYPE promvarz_test_native gauge + promvarz_test_native 4242 + ` + if err := testutil.ScrapeAndCompare(svr.URL, strings.NewReader(want), "promvarz_test_expvar", "promvarz_test_native"); err != nil { + t.Error(err) + } +} diff --git a/types/appctype/appconnector_test.go b/types/appctype/appconnector_test.go index 8aef135b4..390d1776a 100644 --- a/types/appctype/appconnector_test.go +++ b/types/appctype/appconnector_test.go @@ -1,78 +1,78 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package appctype - -import ( - "encoding/json" - "net/netip" - "strings" - "testing" - - "github.com/google/go-cmp/cmp" - "tailscale.com/tailcfg" - "tailscale.com/util/must" -) - -var golden = `{ - "dnat": { - "opaqueid1": { - "addrs": ["100.64.0.1", "fd7a:115c:a1e0::1"], - "to": ["example.org"], - "ip": ["*"] - } - }, - "sniProxy": { - "opaqueid2": { - "addrs": ["::"], - "ip": ["tcp:443"], - "allowedDomains": ["*"] - } - }, - "advertiseRoutes": true -}` - -func TestGolden(t *testing.T) { - wantDNAT := map[ConfigID]DNATConfig{"opaqueid1": { - Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}, - To: []string{"example.org"}, - IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}}, - }} - - wantSNI := map[ConfigID]SNIProxyConfig{"opaqueid2": { - Addrs: []netip.Addr{netip.MustParseAddr("::")}, - IP: []tailcfg.ProtoPortRange{{Proto: 6, Ports: tailcfg.PortRange{First: 443, Last: 443}}}, - AllowedDomains: []string{"*"}, - }} - - var config AppConnectorConfig - if err := json.NewDecoder(strings.NewReader(golden)).Decode(&config); err != nil { - t.Fatalf("failed to decode golden config: %v", err) - } - - if !config.AdvertiseRoutes { - t.Fatalf("expected AdvertiseRoutes to be true, got false") - } - - assertEqual(t, "DNAT", config.DNAT, wantDNAT) - assertEqual(t, "SNI", config.SNIProxy, wantSNI) -} - -func TestRoundTrip(t *testing.T) { - var config AppConnectorConfig - must.Do(json.NewDecoder(strings.NewReader(golden)).Decode(&config)) - b := must.Get(json.Marshal(config)) - var config2 AppConnectorConfig - must.Do(json.Unmarshal(b, &config2)) - assertEqual(t, "DNAT", config.DNAT, config2.DNAT) -} - -func assertEqual(t *testing.T, name string, a, b any) { - var addrComparer = cmp.Comparer(func(a, b netip.Addr) bool { - return a.Compare(b) == 0 - }) - t.Helper() - if diff := cmp.Diff(a, b, addrComparer); diff != "" { - t.Fatalf("mismatch (-want +got):\n%s", diff) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package appctype + +import ( + "encoding/json" + "net/netip" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/tailcfg" + "tailscale.com/util/must" +) + +var golden = `{ + "dnat": { + "opaqueid1": { + "addrs": ["100.64.0.1", "fd7a:115c:a1e0::1"], + "to": ["example.org"], + "ip": ["*"] + } + }, + "sniProxy": { + "opaqueid2": { + "addrs": ["::"], + "ip": ["tcp:443"], + "allowedDomains": ["*"] + } + }, + "advertiseRoutes": true +}` + +func TestGolden(t *testing.T) { + wantDNAT := map[ConfigID]DNATConfig{"opaqueid1": { + Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}, + To: []string{"example.org"}, + IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}}, + }} + + wantSNI := map[ConfigID]SNIProxyConfig{"opaqueid2": { + Addrs: []netip.Addr{netip.MustParseAddr("::")}, + IP: []tailcfg.ProtoPortRange{{Proto: 6, Ports: tailcfg.PortRange{First: 443, Last: 443}}}, + AllowedDomains: []string{"*"}, + }} + + var config AppConnectorConfig + if err := json.NewDecoder(strings.NewReader(golden)).Decode(&config); err != nil { + t.Fatalf("failed to decode golden config: %v", err) + } + + if !config.AdvertiseRoutes { + t.Fatalf("expected AdvertiseRoutes to be true, got false") + } + + assertEqual(t, "DNAT", config.DNAT, wantDNAT) + assertEqual(t, "SNI", config.SNIProxy, wantSNI) +} + +func TestRoundTrip(t *testing.T) { + var config AppConnectorConfig + must.Do(json.NewDecoder(strings.NewReader(golden)).Decode(&config)) + b := must.Get(json.Marshal(config)) + var config2 AppConnectorConfig + must.Do(json.Unmarshal(b, &config2)) + assertEqual(t, "DNAT", config.DNAT, config2.DNAT) +} + +func assertEqual(t *testing.T, name string, a, b any) { + var addrComparer = cmp.Comparer(func(a, b netip.Addr) bool { + return a.Compare(b) == 0 + }) + t.Helper() + if diff := cmp.Diff(a, b, addrComparer); diff != "" { + t.Fatalf("mismatch (-want +got):\n%s", diff) + } +} diff --git a/types/dnstype/dnstype.go b/types/dnstype/dnstype.go index 6cc91c999..b7f5b9d02 100644 --- a/types/dnstype/dnstype.go +++ b/types/dnstype/dnstype.go @@ -1,68 +1,68 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package dnstype defines types for working with DNS. -package dnstype - -//go:generate go run tailscale.com/cmd/viewer --type=Resolver --clonefunc=true - -import ( - "net/netip" - "slices" -) - -// Resolver is the configuration for one DNS resolver. -type Resolver struct { - // Addr is the address of the DNS resolver, one of: - // - A plain IP address for a "classic" UDP+TCP DNS resolver. - // This is the common format as sent by the control plane. - // - An IP:port, for tests. - // - "https://resolver.com/path" for DNS over HTTPS; currently - // as of 2022-09-08 only used for certain well-known resolvers - // (see the publicdns package) for which the IP addresses to dial DoH are - // known ahead of time, so bootstrap DNS resolution is not required. - // - "http://node-address:port/path" for DNS over HTTP over WireGuard. This - // is implemented in the PeerAPI for exit nodes and app connectors. - // - [TODO] "tls://resolver.com" for DNS over TCP+TLS - Addr string `json:",omitempty"` - - // BootstrapResolution is an optional suggested resolution for the - // DoT/DoH resolver, if the resolver URL does not reference an IP - // address directly. - // BootstrapResolution may be empty, in which case clients should - // look up the DoT/DoH server using their local "classic" DNS - // resolver. - // - // As of 2022-09-08, BootstrapResolution is not yet used. - BootstrapResolution []netip.Addr `json:",omitempty"` -} - -// IPPort returns r.Addr as an IP address and port if either -// r.Addr is an IP address (the common case) or if r.Addr -// is an IP:port (as done in tests). -func (r *Resolver) IPPort() (ipp netip.AddrPort, ok bool) { - if r.Addr == "" || r.Addr[0] == 'h' || r.Addr[0] == 't' { - // Fast path to avoid ParseIP error allocation for obviously not IP - // cases. - return - } - if ip, err := netip.ParseAddr(r.Addr); err == nil { - return netip.AddrPortFrom(ip, 53), true - } - if ipp, err := netip.ParseAddrPort(r.Addr); err == nil { - return ipp, true - } - return -} - -// Equal reports whether r and other are equal. -func (r *Resolver) Equal(other *Resolver) bool { - if r == nil || other == nil { - return r == other - } - if r == other { - return true - } - - return r.Addr == other.Addr && slices.Equal(r.BootstrapResolution, other.BootstrapResolution) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package dnstype defines types for working with DNS. +package dnstype + +//go:generate go run tailscale.com/cmd/viewer --type=Resolver --clonefunc=true + +import ( + "net/netip" + "slices" +) + +// Resolver is the configuration for one DNS resolver. +type Resolver struct { + // Addr is the address of the DNS resolver, one of: + // - A plain IP address for a "classic" UDP+TCP DNS resolver. + // This is the common format as sent by the control plane. + // - An IP:port, for tests. + // - "https://resolver.com/path" for DNS over HTTPS; currently + // as of 2022-09-08 only used for certain well-known resolvers + // (see the publicdns package) for which the IP addresses to dial DoH are + // known ahead of time, so bootstrap DNS resolution is not required. + // - "http://node-address:port/path" for DNS over HTTP over WireGuard. This + // is implemented in the PeerAPI for exit nodes and app connectors. + // - [TODO] "tls://resolver.com" for DNS over TCP+TLS + Addr string `json:",omitempty"` + + // BootstrapResolution is an optional suggested resolution for the + // DoT/DoH resolver, if the resolver URL does not reference an IP + // address directly. + // BootstrapResolution may be empty, in which case clients should + // look up the DoT/DoH server using their local "classic" DNS + // resolver. + // + // As of 2022-09-08, BootstrapResolution is not yet used. + BootstrapResolution []netip.Addr `json:",omitempty"` +} + +// IPPort returns r.Addr as an IP address and port if either +// r.Addr is an IP address (the common case) or if r.Addr +// is an IP:port (as done in tests). +func (r *Resolver) IPPort() (ipp netip.AddrPort, ok bool) { + if r.Addr == "" || r.Addr[0] == 'h' || r.Addr[0] == 't' { + // Fast path to avoid ParseIP error allocation for obviously not IP + // cases. + return + } + if ip, err := netip.ParseAddr(r.Addr); err == nil { + return netip.AddrPortFrom(ip, 53), true + } + if ipp, err := netip.ParseAddrPort(r.Addr); err == nil { + return ipp, true + } + return +} + +// Equal reports whether r and other are equal. +func (r *Resolver) Equal(other *Resolver) bool { + if r == nil || other == nil { + return r == other + } + if r == other { + return true + } + + return r.Addr == other.Addr && slices.Equal(r.BootstrapResolution, other.BootstrapResolution) +} diff --git a/types/empty/message.go b/types/empty/message.go index 5ada7f402..dc8eb4cc2 100644 --- a/types/empty/message.go +++ b/types/empty/message.go @@ -1,13 +1,13 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package empty defines an empty struct type. -package empty - -// Message is an empty message. Its purpose is to be used as pointer -// type where nil and non-nil distinguish whether it's set. This is -// used instead of a bool when we want to marshal it as a JSON empty -// object (or null) for the future ability to add other fields, at -// which point callers would define a new struct and not use -// empty.Message. -type Message struct{} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package empty defines an empty struct type. +package empty + +// Message is an empty message. Its purpose is to be used as pointer +// type where nil and non-nil distinguish whether it's set. This is +// used instead of a bool when we want to marshal it as a JSON empty +// object (or null) for the future ability to add other fields, at +// which point callers would define a new struct and not use +// empty.Message. +type Message struct{} diff --git a/types/flagtype/flagtype.go b/types/flagtype/flagtype.go index c76b16353..be160dee8 100644 --- a/types/flagtype/flagtype.go +++ b/types/flagtype/flagtype.go @@ -1,45 +1,45 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package flagtype defines flag.Value types. -package flagtype - -import ( - "errors" - "flag" - "fmt" - "math" - "strconv" - "strings" -) - -type portValue struct{ n *uint16 } - -func PortValue(dst *uint16, defaultPort uint16) flag.Value { - *dst = defaultPort - return portValue{dst} -} - -func (p portValue) String() string { - if p.n == nil { - return "" - } - return fmt.Sprint(*p.n) -} -func (p portValue) Set(v string) error { - if v == "" { - return errors.New("can't be the empty string") - } - if strings.Contains(v, ":") { - return errors.New("expecting just a port number, without a colon") - } - n, err := strconv.ParseUint(v, 10, 64) // use 64 instead of 16 to return nicer error message - if err != nil { - return fmt.Errorf("not a valid number") - } - if n > math.MaxUint16 { - return errors.New("out of range for port number") - } - *p.n = uint16(n) - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package flagtype defines flag.Value types. +package flagtype + +import ( + "errors" + "flag" + "fmt" + "math" + "strconv" + "strings" +) + +type portValue struct{ n *uint16 } + +func PortValue(dst *uint16, defaultPort uint16) flag.Value { + *dst = defaultPort + return portValue{dst} +} + +func (p portValue) String() string { + if p.n == nil { + return "" + } + return fmt.Sprint(*p.n) +} +func (p portValue) Set(v string) error { + if v == "" { + return errors.New("can't be the empty string") + } + if strings.Contains(v, ":") { + return errors.New("expecting just a port number, without a colon") + } + n, err := strconv.ParseUint(v, 10, 64) // use 64 instead of 16 to return nicer error message + if err != nil { + return fmt.Errorf("not a valid number") + } + if n > math.MaxUint16 { + return errors.New("out of range for port number") + } + *p.n = uint16(n) + return nil +} diff --git a/types/ipproto/ipproto.go b/types/ipproto/ipproto.go index 97fc4f3dd..b5333eb56 100644 --- a/types/ipproto/ipproto.go +++ b/types/ipproto/ipproto.go @@ -1,199 +1,199 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package ipproto contains IP Protocol constants. -package ipproto - -import ( - "fmt" - "strconv" - - "tailscale.com/util/nocasemaps" - "tailscale.com/util/vizerror" -) - -// Version describes the IP address version. -type Version uint8 - -// Valid Version values. -const ( - Version4 = 4 - Version6 = 6 -) - -func (p Version) String() string { - switch p { - case Version4: - return "IPv4" - case Version6: - return "IPv6" - default: - return fmt.Sprintf("Version-%d", int(p)) - } -} - -// Proto is an IP subprotocol as defined by the IANA protocol -// numbers list -// (https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml), -// or the special values Unknown or Fragment. -type Proto uint8 - -const ( - // Unknown represents an unknown or unsupported protocol; it's - // deliberately the zero value. Strictly speaking the zero - // value is IPv6 hop-by-hop extensions, but we don't support - // those, so this is still technically correct. - Unknown Proto = 0x00 - - // Values from the IANA registry. - ICMPv4 Proto = 0x01 - IGMP Proto = 0x02 - ICMPv6 Proto = 0x3a - TCP Proto = 0x06 - UDP Proto = 0x11 - DCCP Proto = 0x21 - GRE Proto = 0x2f - SCTP Proto = 0x84 - - // TSMP is the Tailscale Message Protocol (our ICMP-ish - // thing), an IP protocol used only between Tailscale nodes - // (still encrypted by WireGuard) that communicates why things - // failed, etc. - // - // Proto number 99 is reserved for "any private encryption - // scheme". We never accept these from the host OS stack nor - // send them to the host network stack. It's only used between - // nodes. - TSMP Proto = 99 - - // Fragment represents any non-first IP fragment, for which we - // don't have the sub-protocol header (and therefore can't - // figure out what the sub-protocol is). - // - // 0xFF is reserved in the IANA registry, so we steal it for - // internal use. - Fragment Proto = 0xFF -) - -// Deprecated: use MarshalText instead. -func (p Proto) String() string { - switch p { - case Unknown: - return "Unknown" - case Fragment: - return "Frag" - case ICMPv4: - return "ICMPv4" - case IGMP: - return "IGMP" - case ICMPv6: - return "ICMPv6" - case UDP: - return "UDP" - case TCP: - return "TCP" - case SCTP: - return "SCTP" - case TSMP: - return "TSMP" - case GRE: - return "GRE" - case DCCP: - return "DCCP" - default: - return fmt.Sprintf("IPProto-%d", int(p)) - } -} - -// Prefer names from -// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml -// unless otherwise noted. -var ( - // preferredNames is the set of protocol names that re produced by - // MarshalText, and are the preferred representation. - preferredNames = map[Proto]string{ - 51: "ah", - DCCP: "dccp", - 8: "egp", - 50: "esp", - 47: "gre", - ICMPv4: "icmp", - IGMP: "igmp", - 9: "igp", - 4: "ipv4", - ICMPv6: "ipv6-icmp", - SCTP: "sctp", - TCP: "tcp", - UDP: "udp", - } - - // acceptedNames is the set of protocol names that are accepted by - // UnmarshalText. - acceptedNames = map[string]Proto{ - "ah": 51, - "dccp": DCCP, - "egp": 8, - "esp": 50, - "gre": 47, - "icmp": ICMPv4, - "icmpv4": ICMPv4, - "icmpv6": ICMPv6, - "igmp": IGMP, - "igp": 9, - "ip-in-ip": 4, // IANA says "ipv4"; Wikipedia/popular use says "ip-in-ip" - "ipv4": 4, - "ipv6-icmp": ICMPv6, - "sctp": SCTP, - "tcp": TCP, - "tsmp": TSMP, - "udp": UDP, - } -) - -// UnmarshalText implements encoding.TextUnmarshaler. If the input is empty, p -// is set to 0. If an error occurs, p is unchanged. -func (p *Proto) UnmarshalText(b []byte) error { - if len(b) == 0 { - *p = 0 - return nil - } - - if u, err := strconv.ParseUint(string(b), 10, 8); err == nil { - *p = Proto(u) - return nil - } - - if newP, ok := nocasemaps.GetOk(acceptedNames, string(b)); ok { - *p = newP - return nil - } - - return vizerror.Errorf("proto name %q not known; use protocol number 0-255", b) -} - -// MarshalText implements encoding.TextMarshaler. -func (p Proto) MarshalText() ([]byte, error) { - if s, ok := preferredNames[p]; ok { - return []byte(s), nil - } - return []byte(strconv.Itoa(int(p))), nil -} - -// MarshalJSON implements json.Marshaler. -func (p Proto) MarshalJSON() ([]byte, error) { - return []byte(strconv.Itoa(int(p))), nil -} - -// UnmarshalJSON implements json.Unmarshaler. If the input is empty, p is set to -// 0. If an error occurs, p is unchanged. The input must be a JSON number or an -// accepted string name. -func (p *Proto) UnmarshalJSON(b []byte) error { - if len(b) == 0 { - *p = 0 - return nil - } - if b[0] == '"' { - b = b[1 : len(b)-1] - } - return p.UnmarshalText(b) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package ipproto contains IP Protocol constants. +package ipproto + +import ( + "fmt" + "strconv" + + "tailscale.com/util/nocasemaps" + "tailscale.com/util/vizerror" +) + +// Version describes the IP address version. +type Version uint8 + +// Valid Version values. +const ( + Version4 = 4 + Version6 = 6 +) + +func (p Version) String() string { + switch p { + case Version4: + return "IPv4" + case Version6: + return "IPv6" + default: + return fmt.Sprintf("Version-%d", int(p)) + } +} + +// Proto is an IP subprotocol as defined by the IANA protocol +// numbers list +// (https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml), +// or the special values Unknown or Fragment. +type Proto uint8 + +const ( + // Unknown represents an unknown or unsupported protocol; it's + // deliberately the zero value. Strictly speaking the zero + // value is IPv6 hop-by-hop extensions, but we don't support + // those, so this is still technically correct. + Unknown Proto = 0x00 + + // Values from the IANA registry. + ICMPv4 Proto = 0x01 + IGMP Proto = 0x02 + ICMPv6 Proto = 0x3a + TCP Proto = 0x06 + UDP Proto = 0x11 + DCCP Proto = 0x21 + GRE Proto = 0x2f + SCTP Proto = 0x84 + + // TSMP is the Tailscale Message Protocol (our ICMP-ish + // thing), an IP protocol used only between Tailscale nodes + // (still encrypted by WireGuard) that communicates why things + // failed, etc. + // + // Proto number 99 is reserved for "any private encryption + // scheme". We never accept these from the host OS stack nor + // send them to the host network stack. It's only used between + // nodes. + TSMP Proto = 99 + + // Fragment represents any non-first IP fragment, for which we + // don't have the sub-protocol header (and therefore can't + // figure out what the sub-protocol is). + // + // 0xFF is reserved in the IANA registry, so we steal it for + // internal use. + Fragment Proto = 0xFF +) + +// Deprecated: use MarshalText instead. +func (p Proto) String() string { + switch p { + case Unknown: + return "Unknown" + case Fragment: + return "Frag" + case ICMPv4: + return "ICMPv4" + case IGMP: + return "IGMP" + case ICMPv6: + return "ICMPv6" + case UDP: + return "UDP" + case TCP: + return "TCP" + case SCTP: + return "SCTP" + case TSMP: + return "TSMP" + case GRE: + return "GRE" + case DCCP: + return "DCCP" + default: + return fmt.Sprintf("IPProto-%d", int(p)) + } +} + +// Prefer names from +// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml +// unless otherwise noted. +var ( + // preferredNames is the set of protocol names that re produced by + // MarshalText, and are the preferred representation. + preferredNames = map[Proto]string{ + 51: "ah", + DCCP: "dccp", + 8: "egp", + 50: "esp", + 47: "gre", + ICMPv4: "icmp", + IGMP: "igmp", + 9: "igp", + 4: "ipv4", + ICMPv6: "ipv6-icmp", + SCTP: "sctp", + TCP: "tcp", + UDP: "udp", + } + + // acceptedNames is the set of protocol names that are accepted by + // UnmarshalText. + acceptedNames = map[string]Proto{ + "ah": 51, + "dccp": DCCP, + "egp": 8, + "esp": 50, + "gre": 47, + "icmp": ICMPv4, + "icmpv4": ICMPv4, + "icmpv6": ICMPv6, + "igmp": IGMP, + "igp": 9, + "ip-in-ip": 4, // IANA says "ipv4"; Wikipedia/popular use says "ip-in-ip" + "ipv4": 4, + "ipv6-icmp": ICMPv6, + "sctp": SCTP, + "tcp": TCP, + "tsmp": TSMP, + "udp": UDP, + } +) + +// UnmarshalText implements encoding.TextUnmarshaler. If the input is empty, p +// is set to 0. If an error occurs, p is unchanged. +func (p *Proto) UnmarshalText(b []byte) error { + if len(b) == 0 { + *p = 0 + return nil + } + + if u, err := strconv.ParseUint(string(b), 10, 8); err == nil { + *p = Proto(u) + return nil + } + + if newP, ok := nocasemaps.GetOk(acceptedNames, string(b)); ok { + *p = newP + return nil + } + + return vizerror.Errorf("proto name %q not known; use protocol number 0-255", b) +} + +// MarshalText implements encoding.TextMarshaler. +func (p Proto) MarshalText() ([]byte, error) { + if s, ok := preferredNames[p]; ok { + return []byte(s), nil + } + return []byte(strconv.Itoa(int(p))), nil +} + +// MarshalJSON implements json.Marshaler. +func (p Proto) MarshalJSON() ([]byte, error) { + return []byte(strconv.Itoa(int(p))), nil +} + +// UnmarshalJSON implements json.Unmarshaler. If the input is empty, p is set to +// 0. If an error occurs, p is unchanged. The input must be a JSON number or an +// accepted string name. +func (p *Proto) UnmarshalJSON(b []byte) error { + if len(b) == 0 { + *p = 0 + return nil + } + if b[0] == '"' { + b = b[1 : len(b)-1] + } + return p.UnmarshalText(b) +} diff --git a/types/key/chal.go b/types/key/chal.go index da15dd1f8..742ac5479 100644 --- a/types/key/chal.go +++ b/types/key/chal.go @@ -1,91 +1,91 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package key - -import ( - "errors" - - "go4.org/mem" - "tailscale.com/types/structs" -) - -const ( - // chalPublicHexPrefix is the prefix used to identify a - // hex-encoded challenge public key. - // - // This prefix is used in the control protocol, so cannot be - // changed. - chalPublicHexPrefix = "chalpub:" -) - -// ChallengePrivate is a challenge key, used to test whether clients control a -// key they want to prove ownership of. -// -// A ChallengePrivate is ephemeral and not serialized to the disk or network. -type ChallengePrivate struct { - _ structs.Incomparable // because == isn't constant-time - k [32]byte -} - -// NewChallenge creates and returns a new node private key. -func NewChallenge() ChallengePrivate { - return ChallengePrivate(NewNode()) -} - -// Public returns the ChallengePublic for k. -// Panics if ChallengePublic is zero. -func (k ChallengePrivate) Public() ChallengePublic { - pub := NodePrivate(k).Public() - return ChallengePublic(pub) -} - -// MarshalText implements encoding.TextMarshaler, but by returning an error. -// It shouldn't need to be marshalled anywhere. -func (k ChallengePrivate) MarshalText() ([]byte, error) { - return nil, errors.New("refusing to marshal") -} - -// SealToChallenge is like SealTo, but for a ChallengePublic. -func (k NodePrivate) SealToChallenge(p ChallengePublic, cleartext []byte) (ciphertext []byte) { - return k.SealTo(NodePublic(p), cleartext) -} - -// OpenFrom opens the NaCl box ciphertext, which must be a value -// created by NodePrivate.SealToChallenge, and returns the inner cleartext if -// ciphertext is a valid box from p to k. -func (k ChallengePrivate) OpenFrom(p NodePublic, ciphertext []byte) (cleartext []byte, ok bool) { - return NodePrivate(k).OpenFrom(p, ciphertext) -} - -// ChallengePublic is the public portion of a ChallengePrivate. -type ChallengePublic struct { - k [32]byte -} - -// String returns the output of MarshalText as a string. -func (k ChallengePublic) String() string { - bs, err := k.MarshalText() - if err != nil { - panic(err) - } - return string(bs) -} - -// AppendText implements encoding.TextAppender. -func (k ChallengePublic) AppendText(b []byte) ([]byte, error) { - return appendHexKey(b, chalPublicHexPrefix, k.k[:]), nil -} - -// MarshalText implements encoding.TextMarshaler. -func (k ChallengePublic) MarshalText() ([]byte, error) { - return k.AppendText(nil) -} - -// UnmarshalText implements encoding.TextUnmarshaler. -func (k *ChallengePublic) UnmarshalText(b []byte) error { - return parseHex(k.k[:], mem.B(b), mem.S(chalPublicHexPrefix)) -} - -// IsZero reports whether k is the zero value. -func (k ChallengePublic) IsZero() bool { return k == ChallengePublic{} } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package key + +import ( + "errors" + + "go4.org/mem" + "tailscale.com/types/structs" +) + +const ( + // chalPublicHexPrefix is the prefix used to identify a + // hex-encoded challenge public key. + // + // This prefix is used in the control protocol, so cannot be + // changed. + chalPublicHexPrefix = "chalpub:" +) + +// ChallengePrivate is a challenge key, used to test whether clients control a +// key they want to prove ownership of. +// +// A ChallengePrivate is ephemeral and not serialized to the disk or network. +type ChallengePrivate struct { + _ structs.Incomparable // because == isn't constant-time + k [32]byte +} + +// NewChallenge creates and returns a new node private key. +func NewChallenge() ChallengePrivate { + return ChallengePrivate(NewNode()) +} + +// Public returns the ChallengePublic for k. +// Panics if ChallengePublic is zero. +func (k ChallengePrivate) Public() ChallengePublic { + pub := NodePrivate(k).Public() + return ChallengePublic(pub) +} + +// MarshalText implements encoding.TextMarshaler, but by returning an error. +// It shouldn't need to be marshalled anywhere. +func (k ChallengePrivate) MarshalText() ([]byte, error) { + return nil, errors.New("refusing to marshal") +} + +// SealToChallenge is like SealTo, but for a ChallengePublic. +func (k NodePrivate) SealToChallenge(p ChallengePublic, cleartext []byte) (ciphertext []byte) { + return k.SealTo(NodePublic(p), cleartext) +} + +// OpenFrom opens the NaCl box ciphertext, which must be a value +// created by NodePrivate.SealToChallenge, and returns the inner cleartext if +// ciphertext is a valid box from p to k. +func (k ChallengePrivate) OpenFrom(p NodePublic, ciphertext []byte) (cleartext []byte, ok bool) { + return NodePrivate(k).OpenFrom(p, ciphertext) +} + +// ChallengePublic is the public portion of a ChallengePrivate. +type ChallengePublic struct { + k [32]byte +} + +// String returns the output of MarshalText as a string. +func (k ChallengePublic) String() string { + bs, err := k.MarshalText() + if err != nil { + panic(err) + } + return string(bs) +} + +// AppendText implements encoding.TextAppender. +func (k ChallengePublic) AppendText(b []byte) ([]byte, error) { + return appendHexKey(b, chalPublicHexPrefix, k.k[:]), nil +} + +// MarshalText implements encoding.TextMarshaler. +func (k ChallengePublic) MarshalText() ([]byte, error) { + return k.AppendText(nil) +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (k *ChallengePublic) UnmarshalText(b []byte) error { + return parseHex(k.k[:], mem.B(b), mem.S(chalPublicHexPrefix)) +} + +// IsZero reports whether k is the zero value. +func (k ChallengePublic) IsZero() bool { return k == ChallengePublic{} } diff --git a/types/key/control.go b/types/key/control.go index a84359771..96021249b 100644 --- a/types/key/control.go +++ b/types/key/control.go @@ -1,68 +1,68 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package key - -import "encoding/json" - -// ControlPrivate is a Tailscale control plane private key. -// -// It is functionally equivalent to a MachinePrivate, but serializes -// to JSON as a byte array rather than a typed string, because our -// control plane database stores the key that way. -// -// Deprecated: this type should only be used in Tailscale's control -// plane, where existing database serializations require this -// less-good serialization format to persist. Other control plane -// implementations can use MachinePrivate with no downsides. -type ControlPrivate struct { - mkey MachinePrivate // unexported so we can limit the API surface to only exactly what we need -} - -// NewControl generates and returns a new control plane private key. -func NewControl() ControlPrivate { - return ControlPrivate{NewMachine()} -} - -// IsZero reports whether k is the zero value. -func (k ControlPrivate) IsZero() bool { - return k.mkey.IsZero() -} - -// Public returns the MachinePublic for k. -// Panics if ControlPrivate is zero. -func (k ControlPrivate) Public() MachinePublic { - return k.mkey.Public() -} - -// MarshalJSON implements json.Marshaler. -func (k ControlPrivate) MarshalJSON() ([]byte, error) { - return json.Marshal(k.mkey.k) -} - -// UnmarshalJSON implements json.Unmarshaler. -func (k *ControlPrivate) UnmarshalJSON(bs []byte) error { - return json.Unmarshal(bs, &k.mkey.k) -} - -// SealTo wraps cleartext into a NaCl box (see -// golang.org/x/crypto/nacl) to p, authenticated from k, using a -// random nonce. -// -// The returned ciphertext is a 24-byte nonce concatenated with the -// box value. -func (k ControlPrivate) SealTo(p MachinePublic, cleartext []byte) (ciphertext []byte) { - return k.mkey.SealTo(p, cleartext) -} - -// SharedKey returns the precomputed Nacl box shared key between k and p. -func (k ControlPrivate) SharedKey(p MachinePublic) MachinePrecomputedSharedKey { - return k.mkey.SharedKey(p) -} - -// OpenFrom opens the NaCl box ciphertext, which must be a value -// created by SealTo, and returns the inner cleartext if ciphertext is -// a valid box from p to k. -func (k ControlPrivate) OpenFrom(p MachinePublic, ciphertext []byte) (cleartext []byte, ok bool) { - return k.mkey.OpenFrom(p, ciphertext) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package key + +import "encoding/json" + +// ControlPrivate is a Tailscale control plane private key. +// +// It is functionally equivalent to a MachinePrivate, but serializes +// to JSON as a byte array rather than a typed string, because our +// control plane database stores the key that way. +// +// Deprecated: this type should only be used in Tailscale's control +// plane, where existing database serializations require this +// less-good serialization format to persist. Other control plane +// implementations can use MachinePrivate with no downsides. +type ControlPrivate struct { + mkey MachinePrivate // unexported so we can limit the API surface to only exactly what we need +} + +// NewControl generates and returns a new control plane private key. +func NewControl() ControlPrivate { + return ControlPrivate{NewMachine()} +} + +// IsZero reports whether k is the zero value. +func (k ControlPrivate) IsZero() bool { + return k.mkey.IsZero() +} + +// Public returns the MachinePublic for k. +// Panics if ControlPrivate is zero. +func (k ControlPrivate) Public() MachinePublic { + return k.mkey.Public() +} + +// MarshalJSON implements json.Marshaler. +func (k ControlPrivate) MarshalJSON() ([]byte, error) { + return json.Marshal(k.mkey.k) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (k *ControlPrivate) UnmarshalJSON(bs []byte) error { + return json.Unmarshal(bs, &k.mkey.k) +} + +// SealTo wraps cleartext into a NaCl box (see +// golang.org/x/crypto/nacl) to p, authenticated from k, using a +// random nonce. +// +// The returned ciphertext is a 24-byte nonce concatenated with the +// box value. +func (k ControlPrivate) SealTo(p MachinePublic, cleartext []byte) (ciphertext []byte) { + return k.mkey.SealTo(p, cleartext) +} + +// SharedKey returns the precomputed Nacl box shared key between k and p. +func (k ControlPrivate) SharedKey(p MachinePublic) MachinePrecomputedSharedKey { + return k.mkey.SharedKey(p) +} + +// OpenFrom opens the NaCl box ciphertext, which must be a value +// created by SealTo, and returns the inner cleartext if ciphertext is +// a valid box from p to k. +func (k ControlPrivate) OpenFrom(p MachinePublic, ciphertext []byte) (cleartext []byte, ok bool) { + return k.mkey.OpenFrom(p, ciphertext) +} diff --git a/types/key/control_test.go b/types/key/control_test.go index 06e0f36d5..a98a586f3 100644 --- a/types/key/control_test.go +++ b/types/key/control_test.go @@ -1,38 +1,38 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package key - -import ( - "encoding/json" - "testing" -) - -func TestControlKey(t *testing.T) { - serialized := `{"PrivateKey":[36,132,249,6,73,141,249,49,9,96,49,60,240,217,253,57,3,69,248,64,178,62,121,73,121,88,115,218,130,145,68,254]}` - want := ControlPrivate{ - MachinePrivate{ - k: [32]byte{36, 132, 249, 6, 73, 141, 249, 49, 9, 96, 49, 60, 240, 217, 253, 57, 3, 69, 248, 64, 178, 62, 121, 73, 121, 88, 115, 218, 130, 145, 68, 254}, - }, - } - - var got struct { - PrivateKey ControlPrivate - } - if err := json.Unmarshal([]byte(serialized), &got); err != nil { - t.Fatalf("decoding serialized ControlPrivate: %v", err) - } - - if !got.PrivateKey.mkey.Equal(want.mkey) { - t.Fatalf("Serialized ControlPrivate didn't deserialize as expected, got %v want %v", got.PrivateKey, want) - } - - bs, err := json.Marshal(got) - if err != nil { - t.Fatalf("json reserialization of ControlPrivate failed: %v", err) - } - - if got, want := string(bs), serialized; got != want { - t.Fatalf("ControlPrivate didn't round-trip, got %q want %q", got, want) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package key + +import ( + "encoding/json" + "testing" +) + +func TestControlKey(t *testing.T) { + serialized := `{"PrivateKey":[36,132,249,6,73,141,249,49,9,96,49,60,240,217,253,57,3,69,248,64,178,62,121,73,121,88,115,218,130,145,68,254]}` + want := ControlPrivate{ + MachinePrivate{ + k: [32]byte{36, 132, 249, 6, 73, 141, 249, 49, 9, 96, 49, 60, 240, 217, 253, 57, 3, 69, 248, 64, 178, 62, 121, 73, 121, 88, 115, 218, 130, 145, 68, 254}, + }, + } + + var got struct { + PrivateKey ControlPrivate + } + if err := json.Unmarshal([]byte(serialized), &got); err != nil { + t.Fatalf("decoding serialized ControlPrivate: %v", err) + } + + if !got.PrivateKey.mkey.Equal(want.mkey) { + t.Fatalf("Serialized ControlPrivate didn't deserialize as expected, got %v want %v", got.PrivateKey, want) + } + + bs, err := json.Marshal(got) + if err != nil { + t.Fatalf("json reserialization of ControlPrivate failed: %v", err) + } + + if got, want := string(bs), serialized; got != want { + t.Fatalf("ControlPrivate didn't round-trip, got %q want %q", got, want) + } +} diff --git a/types/key/disco_test.go b/types/key/disco_test.go index c9d60c828..c62c13cbf 100644 --- a/types/key/disco_test.go +++ b/types/key/disco_test.go @@ -1,83 +1,83 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package key - -import ( - "bytes" - "encoding/json" - "testing" -) - -func TestDiscoKey(t *testing.T) { - k := NewDisco() - if k.IsZero() { - t.Fatal("DiscoPrivate should not be zero") - } - - p := k.Public() - if p.IsZero() { - t.Fatal("DiscoPublic should not be zero") - } - - bs, err := p.MarshalText() - if err != nil { - t.Fatal(err) - } - if !bytes.HasPrefix(bs, []byte("discokey:")) { - t.Fatalf("serialization of public discokey %s has wrong prefix", p) - } - - z := DiscoPublic{} - if !z.IsZero() { - t.Fatal("IsZero(DiscoPublic{}) is false") - } - if s := z.ShortString(); s != "" { - t.Fatalf("DiscoPublic{}.ShortString() is %q, want \"\"", s) - } -} - -func TestDiscoSerialization(t *testing.T) { - serialized := `{ - "Pub":"discokey:50d20b455ecf12bc453f83c2cfdb2a24925d06cf2598dcaa54e91af82ce9f765" - }` - - pub := DiscoPublic{ - k: [32]uint8{ - 0x50, 0xd2, 0xb, 0x45, 0x5e, 0xcf, 0x12, 0xbc, 0x45, 0x3f, 0x83, - 0xc2, 0xcf, 0xdb, 0x2a, 0x24, 0x92, 0x5d, 0x6, 0xcf, 0x25, 0x98, - 0xdc, 0xaa, 0x54, 0xe9, 0x1a, 0xf8, 0x2c, 0xe9, 0xf7, 0x65, - }, - } - - type key struct { - Pub DiscoPublic - } - - var a key - if err := json.Unmarshal([]byte(serialized), &a); err != nil { - t.Fatal(err) - } - if a.Pub != pub { - t.Errorf("wrong deserialization of public key, got %#v want %#v", a.Pub, pub) - } - - bs, err := json.MarshalIndent(a, "", " ") - if err != nil { - t.Fatal(err) - } - - var b bytes.Buffer - json.Indent(&b, []byte(serialized), "", " ") - if got, want := string(bs), b.String(); got != want { - t.Error("json serialization doesn't roundtrip") - } -} - -func TestDiscoShared(t *testing.T) { - k1, k2 := NewDisco(), NewDisco() - s1, s2 := k1.Shared(k2.Public()), k2.Shared(k1.Public()) - if !s1.Equal(s2) { - t.Error("k1.Shared(k2) != k2.Shared(k1)") - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package key + +import ( + "bytes" + "encoding/json" + "testing" +) + +func TestDiscoKey(t *testing.T) { + k := NewDisco() + if k.IsZero() { + t.Fatal("DiscoPrivate should not be zero") + } + + p := k.Public() + if p.IsZero() { + t.Fatal("DiscoPublic should not be zero") + } + + bs, err := p.MarshalText() + if err != nil { + t.Fatal(err) + } + if !bytes.HasPrefix(bs, []byte("discokey:")) { + t.Fatalf("serialization of public discokey %s has wrong prefix", p) + } + + z := DiscoPublic{} + if !z.IsZero() { + t.Fatal("IsZero(DiscoPublic{}) is false") + } + if s := z.ShortString(); s != "" { + t.Fatalf("DiscoPublic{}.ShortString() is %q, want \"\"", s) + } +} + +func TestDiscoSerialization(t *testing.T) { + serialized := `{ + "Pub":"discokey:50d20b455ecf12bc453f83c2cfdb2a24925d06cf2598dcaa54e91af82ce9f765" + }` + + pub := DiscoPublic{ + k: [32]uint8{ + 0x50, 0xd2, 0xb, 0x45, 0x5e, 0xcf, 0x12, 0xbc, 0x45, 0x3f, 0x83, + 0xc2, 0xcf, 0xdb, 0x2a, 0x24, 0x92, 0x5d, 0x6, 0xcf, 0x25, 0x98, + 0xdc, 0xaa, 0x54, 0xe9, 0x1a, 0xf8, 0x2c, 0xe9, 0xf7, 0x65, + }, + } + + type key struct { + Pub DiscoPublic + } + + var a key + if err := json.Unmarshal([]byte(serialized), &a); err != nil { + t.Fatal(err) + } + if a.Pub != pub { + t.Errorf("wrong deserialization of public key, got %#v want %#v", a.Pub, pub) + } + + bs, err := json.MarshalIndent(a, "", " ") + if err != nil { + t.Fatal(err) + } + + var b bytes.Buffer + json.Indent(&b, []byte(serialized), "", " ") + if got, want := string(bs), b.String(); got != want { + t.Error("json serialization doesn't roundtrip") + } +} + +func TestDiscoShared(t *testing.T) { + k1, k2 := NewDisco(), NewDisco() + s1, s2 := k1.Shared(k2.Public()), k2.Shared(k1.Public()) + if !s1.Equal(s2) { + t.Error("k1.Shared(k2) != k2.Shared(k1)") + } +} diff --git a/types/key/machine.go b/types/key/machine.go index 0dc02574c..a05f3cc1f 100644 --- a/types/key/machine.go +++ b/types/key/machine.go @@ -1,264 +1,264 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package key - -import ( - "bytes" - "crypto/subtle" - "encoding/hex" - - "go4.org/mem" - "golang.org/x/crypto/curve25519" - "golang.org/x/crypto/nacl/box" - "tailscale.com/types/structs" -) - -const ( - // machinePrivateHexPrefix is the prefix used to identify a - // hex-encoded machine private key. - // - // This prefix name is a little unfortunate, in that it comes from - // WireGuard's own key types. Unfortunately we're stuck with it for - // machine keys, because we serialize them to disk with this prefix. - machinePrivateHexPrefix = "privkey:" - - // machinePublicHexPrefix is the prefix used to identify a - // hex-encoded machine public key. - // - // This prefix is used in the control protocol, so cannot be - // changed. - machinePublicHexPrefix = "mkey:" -) - -// MachinePrivate is a machine key, used for communication with the -// Tailscale coordination server. -type MachinePrivate struct { - _ structs.Incomparable // == isn't constant-time - k [32]byte -} - -// NewMachine creates and returns a new machine private key. -func NewMachine() MachinePrivate { - var ret MachinePrivate - rand(ret.k[:]) - clamp25519Private(ret.k[:]) - return ret -} - -// IsZero reports whether k is the zero value. -func (k MachinePrivate) IsZero() bool { - return k.Equal(MachinePrivate{}) -} - -// Equal reports whether k and other are the same key. -func (k MachinePrivate) Equal(other MachinePrivate) bool { - return subtle.ConstantTimeCompare(k.k[:], other.k[:]) == 1 -} - -// Public returns the MachinePublic for k. -// Panics if MachinePrivate is zero. -func (k MachinePrivate) Public() MachinePublic { - if k.IsZero() { - panic("can't take the public key of a zero MachinePrivate") - } - var ret MachinePublic - curve25519.ScalarBaseMult(&ret.k, &k.k) - return ret -} - -// AppendText implements encoding.TextAppender. -func (k MachinePrivate) AppendText(b []byte) ([]byte, error) { - return appendHexKey(b, machinePrivateHexPrefix, k.k[:]), nil -} - -// MarshalText implements encoding.TextMarshaler. -func (k MachinePrivate) MarshalText() ([]byte, error) { - return k.AppendText(nil) -} - -// MarshalText implements encoding.TextUnmarshaler. -func (k *MachinePrivate) UnmarshalText(b []byte) error { - return parseHex(k.k[:], mem.B(b), mem.S(machinePrivateHexPrefix)) -} - -// UntypedBytes returns k, encoded as an untyped 64-character hex -// string. -// -// Deprecated: this function is risky to use, because it produces -// serialized values that do not identify themselves as a -// MachinePrivate, allowing other code to potentially parse it back in -// as the wrong key type. For new uses that don't require this -// specific raw byte serialization, please use -// MarshalText/UnmarshalText. -func (k MachinePrivate) UntypedBytes() []byte { - return bytes.Clone(k.k[:]) -} - -// SealTo wraps cleartext into a NaCl box (see -// golang.org/x/crypto/nacl) to p, authenticated from k, using a -// random nonce. -// -// The returned ciphertext is a 24-byte nonce concatenated with the -// box value. -func (k MachinePrivate) SealTo(p MachinePublic, cleartext []byte) (ciphertext []byte) { - if k.IsZero() || p.IsZero() { - panic("can't seal with zero keys") - } - var nonce [24]byte - rand(nonce[:]) - return box.Seal(nonce[:], cleartext, &nonce, &p.k, &k.k) -} - -// SharedKey returns the precomputed Nacl box shared key between k and p. -func (k MachinePrivate) SharedKey(p MachinePublic) MachinePrecomputedSharedKey { - var shared MachinePrecomputedSharedKey - box.Precompute(&shared.k, &p.k, &k.k) - return shared -} - -// MachinePrecomputedSharedKey is a precomputed shared NaCl box shared key. -type MachinePrecomputedSharedKey struct { - k [32]byte -} - -// Seal wraps cleartext into a NaCl box (see -// golang.org/x/crypto/nacl) using the shared key k as generated -// by MachinePrivate.SharedKey. -// -// The returned ciphertext is a 24-byte nonce concatenated with the -// box value. -func (k MachinePrecomputedSharedKey) Seal(cleartext []byte) (ciphertext []byte) { - if k == (MachinePrecomputedSharedKey{}) { - panic("can't seal with zero keys") - } - var nonce [24]byte - rand(nonce[:]) - return box.SealAfterPrecomputation(nonce[:], cleartext, &nonce, &k.k) -} - -// Open opens the NaCl box ciphertext, which must be a value created by -// MachinePrecomputedSharedKey.Seal or MachinePrivate.SealTo, and returns the -// inner cleartext if ciphertext is a valid box for the shared key k. -func (k MachinePrecomputedSharedKey) Open(ciphertext []byte) (cleartext []byte, ok bool) { - if k == (MachinePrecomputedSharedKey{}) { - panic("can't open with zero keys") - } - if len(ciphertext) < 24 { - return nil, false - } - var nonce [24]byte - copy(nonce[:], ciphertext) - return box.OpenAfterPrecomputation(nil, ciphertext[len(nonce):], &nonce, &k.k) -} - -// OpenFrom opens the NaCl box ciphertext, which must be a value -// created by SealTo, and returns the inner cleartext if ciphertext is -// a valid box from p to k. -func (k MachinePrivate) OpenFrom(p MachinePublic, ciphertext []byte) (cleartext []byte, ok bool) { - if k.IsZero() || p.IsZero() { - panic("can't open with zero keys") - } - if len(ciphertext) < 24 { - return nil, false - } - var nonce [24]byte - copy(nonce[:], ciphertext) - return box.Open(nil, ciphertext[len(nonce):], &nonce, &p.k, &k.k) -} - -// MachinePublic is the public portion of a a MachinePrivate. -type MachinePublic struct { - k [32]byte -} - -// MachinePublicFromRaw32 parses a 32-byte raw value as a MachinePublic. -// -// This should be used only when deserializing a MachinePublic from a -// binary protocol. -func MachinePublicFromRaw32(raw mem.RO) MachinePublic { - if raw.Len() != 32 { - panic("input has wrong size") - } - var ret MachinePublic - raw.Copy(ret.k[:]) - return ret -} - -// ParseMachinePublicUntyped parses an untyped 64-character hex value -// as a MachinePublic. -// -// Deprecated: this function is risky to use, because it cannot verify -// that the hex string was intended to be a MachinePublic. This can -// lead to accidentally decoding one type of key as another. For new -// uses that don't require backwards compatibility with the untyped -// string format, please use MarshalText/UnmarshalText. -func ParseMachinePublicUntyped(raw mem.RO) (MachinePublic, error) { - var ret MachinePublic - if err := parseHex(ret.k[:], raw, mem.B(nil)); err != nil { - return MachinePublic{}, err - } - return ret, nil -} - -// IsZero reports whether k is the zero value. -func (k MachinePublic) IsZero() bool { - return k == MachinePublic{} -} - -// ShortString returns the Tailscale conventional debug representation -// of a public key: the first five base64 digits of the key, in square -// brackets. -func (k MachinePublic) ShortString() string { - return debug32(k.k) -} - -// UntypedHexString returns k, encoded as an untyped 64-character hex -// string. -// -// Deprecated: this function is risky to use, because it produces -// serialized values that do not identify themselves as a -// MachinePublic, allowing other code to potentially parse it back in -// as the wrong key type. For new uses that don't require backwards -// compatibility with the untyped string format, please use -// MarshalText/UnmarshalText. -func (k MachinePublic) UntypedHexString() string { - return hex.EncodeToString(k.k[:]) -} - -// UntypedBytes returns k, encoded as an untyped 64-character hex -// string. -// -// Deprecated: this function is risky to use, because it produces -// serialized values that do not identify themselves as a -// MachinePublic, allowing other code to potentially parse it back in -// as the wrong key type. For new uses that don't require this -// specific raw byte serialization, please use -// MarshalText/UnmarshalText. -func (k MachinePublic) UntypedBytes() []byte { - return bytes.Clone(k.k[:]) -} - -// String returns the output of MarshalText as a string. -func (k MachinePublic) String() string { - bs, err := k.MarshalText() - if err != nil { - panic(err) - } - return string(bs) -} - -// AppendText implements encoding.TextAppender. -func (k MachinePublic) AppendText(b []byte) ([]byte, error) { - return appendHexKey(b, machinePublicHexPrefix, k.k[:]), nil -} - -// MarshalText implements encoding.TextMarshaler. -func (k MachinePublic) MarshalText() ([]byte, error) { - return k.AppendText(nil) -} - -// MarshalText implements encoding.TextUnmarshaler. -func (k *MachinePublic) UnmarshalText(b []byte) error { - return parseHex(k.k[:], mem.B(b), mem.S(machinePublicHexPrefix)) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package key + +import ( + "bytes" + "crypto/subtle" + "encoding/hex" + + "go4.org/mem" + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/nacl/box" + "tailscale.com/types/structs" +) + +const ( + // machinePrivateHexPrefix is the prefix used to identify a + // hex-encoded machine private key. + // + // This prefix name is a little unfortunate, in that it comes from + // WireGuard's own key types. Unfortunately we're stuck with it for + // machine keys, because we serialize them to disk with this prefix. + machinePrivateHexPrefix = "privkey:" + + // machinePublicHexPrefix is the prefix used to identify a + // hex-encoded machine public key. + // + // This prefix is used in the control protocol, so cannot be + // changed. + machinePublicHexPrefix = "mkey:" +) + +// MachinePrivate is a machine key, used for communication with the +// Tailscale coordination server. +type MachinePrivate struct { + _ structs.Incomparable // == isn't constant-time + k [32]byte +} + +// NewMachine creates and returns a new machine private key. +func NewMachine() MachinePrivate { + var ret MachinePrivate + rand(ret.k[:]) + clamp25519Private(ret.k[:]) + return ret +} + +// IsZero reports whether k is the zero value. +func (k MachinePrivate) IsZero() bool { + return k.Equal(MachinePrivate{}) +} + +// Equal reports whether k and other are the same key. +func (k MachinePrivate) Equal(other MachinePrivate) bool { + return subtle.ConstantTimeCompare(k.k[:], other.k[:]) == 1 +} + +// Public returns the MachinePublic for k. +// Panics if MachinePrivate is zero. +func (k MachinePrivate) Public() MachinePublic { + if k.IsZero() { + panic("can't take the public key of a zero MachinePrivate") + } + var ret MachinePublic + curve25519.ScalarBaseMult(&ret.k, &k.k) + return ret +} + +// AppendText implements encoding.TextAppender. +func (k MachinePrivate) AppendText(b []byte) ([]byte, error) { + return appendHexKey(b, machinePrivateHexPrefix, k.k[:]), nil +} + +// MarshalText implements encoding.TextMarshaler. +func (k MachinePrivate) MarshalText() ([]byte, error) { + return k.AppendText(nil) +} + +// MarshalText implements encoding.TextUnmarshaler. +func (k *MachinePrivate) UnmarshalText(b []byte) error { + return parseHex(k.k[:], mem.B(b), mem.S(machinePrivateHexPrefix)) +} + +// UntypedBytes returns k, encoded as an untyped 64-character hex +// string. +// +// Deprecated: this function is risky to use, because it produces +// serialized values that do not identify themselves as a +// MachinePrivate, allowing other code to potentially parse it back in +// as the wrong key type. For new uses that don't require this +// specific raw byte serialization, please use +// MarshalText/UnmarshalText. +func (k MachinePrivate) UntypedBytes() []byte { + return bytes.Clone(k.k[:]) +} + +// SealTo wraps cleartext into a NaCl box (see +// golang.org/x/crypto/nacl) to p, authenticated from k, using a +// random nonce. +// +// The returned ciphertext is a 24-byte nonce concatenated with the +// box value. +func (k MachinePrivate) SealTo(p MachinePublic, cleartext []byte) (ciphertext []byte) { + if k.IsZero() || p.IsZero() { + panic("can't seal with zero keys") + } + var nonce [24]byte + rand(nonce[:]) + return box.Seal(nonce[:], cleartext, &nonce, &p.k, &k.k) +} + +// SharedKey returns the precomputed Nacl box shared key between k and p. +func (k MachinePrivate) SharedKey(p MachinePublic) MachinePrecomputedSharedKey { + var shared MachinePrecomputedSharedKey + box.Precompute(&shared.k, &p.k, &k.k) + return shared +} + +// MachinePrecomputedSharedKey is a precomputed shared NaCl box shared key. +type MachinePrecomputedSharedKey struct { + k [32]byte +} + +// Seal wraps cleartext into a NaCl box (see +// golang.org/x/crypto/nacl) using the shared key k as generated +// by MachinePrivate.SharedKey. +// +// The returned ciphertext is a 24-byte nonce concatenated with the +// box value. +func (k MachinePrecomputedSharedKey) Seal(cleartext []byte) (ciphertext []byte) { + if k == (MachinePrecomputedSharedKey{}) { + panic("can't seal with zero keys") + } + var nonce [24]byte + rand(nonce[:]) + return box.SealAfterPrecomputation(nonce[:], cleartext, &nonce, &k.k) +} + +// Open opens the NaCl box ciphertext, which must be a value created by +// MachinePrecomputedSharedKey.Seal or MachinePrivate.SealTo, and returns the +// inner cleartext if ciphertext is a valid box for the shared key k. +func (k MachinePrecomputedSharedKey) Open(ciphertext []byte) (cleartext []byte, ok bool) { + if k == (MachinePrecomputedSharedKey{}) { + panic("can't open with zero keys") + } + if len(ciphertext) < 24 { + return nil, false + } + var nonce [24]byte + copy(nonce[:], ciphertext) + return box.OpenAfterPrecomputation(nil, ciphertext[len(nonce):], &nonce, &k.k) +} + +// OpenFrom opens the NaCl box ciphertext, which must be a value +// created by SealTo, and returns the inner cleartext if ciphertext is +// a valid box from p to k. +func (k MachinePrivate) OpenFrom(p MachinePublic, ciphertext []byte) (cleartext []byte, ok bool) { + if k.IsZero() || p.IsZero() { + panic("can't open with zero keys") + } + if len(ciphertext) < 24 { + return nil, false + } + var nonce [24]byte + copy(nonce[:], ciphertext) + return box.Open(nil, ciphertext[len(nonce):], &nonce, &p.k, &k.k) +} + +// MachinePublic is the public portion of a a MachinePrivate. +type MachinePublic struct { + k [32]byte +} + +// MachinePublicFromRaw32 parses a 32-byte raw value as a MachinePublic. +// +// This should be used only when deserializing a MachinePublic from a +// binary protocol. +func MachinePublicFromRaw32(raw mem.RO) MachinePublic { + if raw.Len() != 32 { + panic("input has wrong size") + } + var ret MachinePublic + raw.Copy(ret.k[:]) + return ret +} + +// ParseMachinePublicUntyped parses an untyped 64-character hex value +// as a MachinePublic. +// +// Deprecated: this function is risky to use, because it cannot verify +// that the hex string was intended to be a MachinePublic. This can +// lead to accidentally decoding one type of key as another. For new +// uses that don't require backwards compatibility with the untyped +// string format, please use MarshalText/UnmarshalText. +func ParseMachinePublicUntyped(raw mem.RO) (MachinePublic, error) { + var ret MachinePublic + if err := parseHex(ret.k[:], raw, mem.B(nil)); err != nil { + return MachinePublic{}, err + } + return ret, nil +} + +// IsZero reports whether k is the zero value. +func (k MachinePublic) IsZero() bool { + return k == MachinePublic{} +} + +// ShortString returns the Tailscale conventional debug representation +// of a public key: the first five base64 digits of the key, in square +// brackets. +func (k MachinePublic) ShortString() string { + return debug32(k.k) +} + +// UntypedHexString returns k, encoded as an untyped 64-character hex +// string. +// +// Deprecated: this function is risky to use, because it produces +// serialized values that do not identify themselves as a +// MachinePublic, allowing other code to potentially parse it back in +// as the wrong key type. For new uses that don't require backwards +// compatibility with the untyped string format, please use +// MarshalText/UnmarshalText. +func (k MachinePublic) UntypedHexString() string { + return hex.EncodeToString(k.k[:]) +} + +// UntypedBytes returns k, encoded as an untyped 64-character hex +// string. +// +// Deprecated: this function is risky to use, because it produces +// serialized values that do not identify themselves as a +// MachinePublic, allowing other code to potentially parse it back in +// as the wrong key type. For new uses that don't require this +// specific raw byte serialization, please use +// MarshalText/UnmarshalText. +func (k MachinePublic) UntypedBytes() []byte { + return bytes.Clone(k.k[:]) +} + +// String returns the output of MarshalText as a string. +func (k MachinePublic) String() string { + bs, err := k.MarshalText() + if err != nil { + panic(err) + } + return string(bs) +} + +// AppendText implements encoding.TextAppender. +func (k MachinePublic) AppendText(b []byte) ([]byte, error) { + return appendHexKey(b, machinePublicHexPrefix, k.k[:]), nil +} + +// MarshalText implements encoding.TextMarshaler. +func (k MachinePublic) MarshalText() ([]byte, error) { + return k.AppendText(nil) +} + +// MarshalText implements encoding.TextUnmarshaler. +func (k *MachinePublic) UnmarshalText(b []byte) error { + return parseHex(k.k[:], mem.B(b), mem.S(machinePublicHexPrefix)) +} diff --git a/types/key/machine_test.go b/types/key/machine_test.go index f797ff087..157df9e43 100644 --- a/types/key/machine_test.go +++ b/types/key/machine_test.go @@ -1,119 +1,119 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package key - -import ( - "bytes" - "encoding/json" - "strings" - "testing" -) - -func TestMachineKey(t *testing.T) { - k := NewMachine() - if k.IsZero() { - t.Fatal("MachinePrivate should not be zero") - } - - p := k.Public() - if p.IsZero() { - t.Fatal("MachinePublic should not be zero") - } - - bs, err := p.MarshalText() - if err != nil { - t.Fatal(err) - } - if full, got := string(bs), ":"+p.UntypedHexString(); !strings.HasSuffix(full, got) { - t.Fatalf("MachinePublic.UntypedHexString is not a suffix of the typed serialization, got %q want suffix of %q", got, full) - } - - z := MachinePublic{} - if !z.IsZero() { - t.Fatal("IsZero(MachinePublic{}) is false") - } - if s := z.ShortString(); s != "" { - t.Fatalf("MachinePublic{}.ShortString() is %q, want \"\"", s) - } -} - -func TestMachineSerialization(t *testing.T) { - serialized := `{ - "Priv": "privkey:40ab1b58e9076c7a4d9d07291f5edf9d1aa017eb949624ba683317f48a640369", - "Pub":"mkey:50d20b455ecf12bc453f83c2cfdb2a24925d06cf2598dcaa54e91af82ce9f765" - }` - - // Carefully check that the expected serialized data decodes and - // reencodes to the expected keys. These types are serialized to - // disk all over the place and need to be stable. - priv := MachinePrivate{ - k: [32]uint8{ - 0x40, 0xab, 0x1b, 0x58, 0xe9, 0x7, 0x6c, 0x7a, 0x4d, 0x9d, 0x7, - 0x29, 0x1f, 0x5e, 0xdf, 0x9d, 0x1a, 0xa0, 0x17, 0xeb, 0x94, - 0x96, 0x24, 0xba, 0x68, 0x33, 0x17, 0xf4, 0x8a, 0x64, 0x3, 0x69, - }, - } - pub := MachinePublic{ - k: [32]uint8{ - 0x50, 0xd2, 0xb, 0x45, 0x5e, 0xcf, 0x12, 0xbc, 0x45, 0x3f, 0x83, - 0xc2, 0xcf, 0xdb, 0x2a, 0x24, 0x92, 0x5d, 0x6, 0xcf, 0x25, 0x98, - 0xdc, 0xaa, 0x54, 0xe9, 0x1a, 0xf8, 0x2c, 0xe9, 0xf7, 0x65, - }, - } - - type keypair struct { - Priv MachinePrivate - Pub MachinePublic - } - - var a keypair - if err := json.Unmarshal([]byte(serialized), &a); err != nil { - t.Fatal(err) - } - if !a.Priv.Equal(priv) { - t.Errorf("wrong deserialization of private key, got %#v want %#v", a.Priv, priv) - } - if a.Pub != pub { - t.Errorf("wrong deserialization of public key, got %#v want %#v", a.Pub, pub) - } - - bs, err := json.MarshalIndent(a, "", " ") - if err != nil { - t.Fatal(err) - } - - var b bytes.Buffer - json.Indent(&b, []byte(serialized), "", " ") - if got, want := string(bs), b.String(); got != want { - t.Error("json serialization doesn't roundtrip") - } -} - -func TestSealViaSharedKey(t *testing.T) { - // encrypt a message from a to b - a := NewMachine() - b := NewMachine() - apub, bpub := a.Public(), b.Public() - - shared := a.SharedKey(bpub) - - const clear = "the eagle flies at midnight" - enc := shared.Seal([]byte(clear)) - - back, ok := b.OpenFrom(apub, enc) - if !ok { - t.Fatal("failed to decrypt") - } - if string(back) != clear { - t.Errorf("OpenFrom got %q; want cleartext %q", back, clear) - } - - backShared, ok := shared.Open(enc) - if !ok { - t.Fatal("failed to decrypt from shared key") - } - if string(backShared) != clear { - t.Errorf("Open got %q; want cleartext %q", back, clear) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package key + +import ( + "bytes" + "encoding/json" + "strings" + "testing" +) + +func TestMachineKey(t *testing.T) { + k := NewMachine() + if k.IsZero() { + t.Fatal("MachinePrivate should not be zero") + } + + p := k.Public() + if p.IsZero() { + t.Fatal("MachinePublic should not be zero") + } + + bs, err := p.MarshalText() + if err != nil { + t.Fatal(err) + } + if full, got := string(bs), ":"+p.UntypedHexString(); !strings.HasSuffix(full, got) { + t.Fatalf("MachinePublic.UntypedHexString is not a suffix of the typed serialization, got %q want suffix of %q", got, full) + } + + z := MachinePublic{} + if !z.IsZero() { + t.Fatal("IsZero(MachinePublic{}) is false") + } + if s := z.ShortString(); s != "" { + t.Fatalf("MachinePublic{}.ShortString() is %q, want \"\"", s) + } +} + +func TestMachineSerialization(t *testing.T) { + serialized := `{ + "Priv": "privkey:40ab1b58e9076c7a4d9d07291f5edf9d1aa017eb949624ba683317f48a640369", + "Pub":"mkey:50d20b455ecf12bc453f83c2cfdb2a24925d06cf2598dcaa54e91af82ce9f765" + }` + + // Carefully check that the expected serialized data decodes and + // reencodes to the expected keys. These types are serialized to + // disk all over the place and need to be stable. + priv := MachinePrivate{ + k: [32]uint8{ + 0x40, 0xab, 0x1b, 0x58, 0xe9, 0x7, 0x6c, 0x7a, 0x4d, 0x9d, 0x7, + 0x29, 0x1f, 0x5e, 0xdf, 0x9d, 0x1a, 0xa0, 0x17, 0xeb, 0x94, + 0x96, 0x24, 0xba, 0x68, 0x33, 0x17, 0xf4, 0x8a, 0x64, 0x3, 0x69, + }, + } + pub := MachinePublic{ + k: [32]uint8{ + 0x50, 0xd2, 0xb, 0x45, 0x5e, 0xcf, 0x12, 0xbc, 0x45, 0x3f, 0x83, + 0xc2, 0xcf, 0xdb, 0x2a, 0x24, 0x92, 0x5d, 0x6, 0xcf, 0x25, 0x98, + 0xdc, 0xaa, 0x54, 0xe9, 0x1a, 0xf8, 0x2c, 0xe9, 0xf7, 0x65, + }, + } + + type keypair struct { + Priv MachinePrivate + Pub MachinePublic + } + + var a keypair + if err := json.Unmarshal([]byte(serialized), &a); err != nil { + t.Fatal(err) + } + if !a.Priv.Equal(priv) { + t.Errorf("wrong deserialization of private key, got %#v want %#v", a.Priv, priv) + } + if a.Pub != pub { + t.Errorf("wrong deserialization of public key, got %#v want %#v", a.Pub, pub) + } + + bs, err := json.MarshalIndent(a, "", " ") + if err != nil { + t.Fatal(err) + } + + var b bytes.Buffer + json.Indent(&b, []byte(serialized), "", " ") + if got, want := string(bs), b.String(); got != want { + t.Error("json serialization doesn't roundtrip") + } +} + +func TestSealViaSharedKey(t *testing.T) { + // encrypt a message from a to b + a := NewMachine() + b := NewMachine() + apub, bpub := a.Public(), b.Public() + + shared := a.SharedKey(bpub) + + const clear = "the eagle flies at midnight" + enc := shared.Seal([]byte(clear)) + + back, ok := b.OpenFrom(apub, enc) + if !ok { + t.Fatal("failed to decrypt") + } + if string(back) != clear { + t.Errorf("OpenFrom got %q; want cleartext %q", back, clear) + } + + backShared, ok := shared.Open(enc) + if !ok { + t.Fatal("failed to decrypt from shared key") + } + if string(backShared) != clear { + t.Errorf("Open got %q; want cleartext %q", back, clear) + } +} diff --git a/types/key/nl_test.go b/types/key/nl_test.go index 2e10d04ac..75b7765a1 100644 --- a/types/key/nl_test.go +++ b/types/key/nl_test.go @@ -1,48 +1,48 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package key - -import ( - "bytes" - "testing" -) - -func TestNLPrivate(t *testing.T) { - p := NewNLPrivate() - - encoded, err := p.MarshalText() - if err != nil { - t.Fatal(err) - } - var decoded NLPrivate - if err := decoded.UnmarshalText(encoded); err != nil { - t.Fatal(err) - } - if !bytes.Equal(decoded.k[:], p.k[:]) { - t.Error("decoded and generated NLPrivate bytes differ") - } - - // Test NLPublic - pub := p.Public() - encoded, err = pub.MarshalText() - if err != nil { - t.Fatal(err) - } - var decodedPub NLPublic - if err := decodedPub.UnmarshalText(encoded); err != nil { - t.Fatal(err) - } - if !bytes.Equal(decodedPub.k[:], pub.k[:]) { - t.Error("decoded and generated NLPublic bytes differ") - } - - // Test decoding with CLI prefix: 'nlpub:' => 'tlpub:' - decodedPub = NLPublic{} - if err := decodedPub.UnmarshalText([]byte(pub.CLIString())); err != nil { - t.Fatal(err) - } - if !bytes.Equal(decodedPub.k[:], pub.k[:]) { - t.Error("decoded and generated NLPublic bytes differ (CLI prefix)") - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package key + +import ( + "bytes" + "testing" +) + +func TestNLPrivate(t *testing.T) { + p := NewNLPrivate() + + encoded, err := p.MarshalText() + if err != nil { + t.Fatal(err) + } + var decoded NLPrivate + if err := decoded.UnmarshalText(encoded); err != nil { + t.Fatal(err) + } + if !bytes.Equal(decoded.k[:], p.k[:]) { + t.Error("decoded and generated NLPrivate bytes differ") + } + + // Test NLPublic + pub := p.Public() + encoded, err = pub.MarshalText() + if err != nil { + t.Fatal(err) + } + var decodedPub NLPublic + if err := decodedPub.UnmarshalText(encoded); err != nil { + t.Fatal(err) + } + if !bytes.Equal(decodedPub.k[:], pub.k[:]) { + t.Error("decoded and generated NLPublic bytes differ") + } + + // Test decoding with CLI prefix: 'nlpub:' => 'tlpub:' + decodedPub = NLPublic{} + if err := decodedPub.UnmarshalText([]byte(pub.CLIString())); err != nil { + t.Fatal(err) + } + if !bytes.Equal(decodedPub.k[:], pub.k[:]) { + t.Error("decoded and generated NLPublic bytes differ (CLI prefix)") + } +} diff --git a/types/lazy/unsync.go b/types/lazy/unsync.go index ca46f9c7b..0f89ce4f6 100644 --- a/types/lazy/unsync.go +++ b/types/lazy/unsync.go @@ -1,99 +1,99 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package lazy - -// GValue is a lazily computed value. -// -// Use either Get or GetErr, depending on whether your fill function returns an -// error. -// -// Recursive use of a GValue from its own fill function will panic. -// -// GValue is not safe for concurrent use. (Mnemonic: G is for one Goroutine, -// which isn't strictly true if you provide your own synchronization between -// goroutines, but in practice most of our callers have been using it within -// a single goroutine.) -type GValue[T any] struct { - done bool - calling bool - V T - err error -} - -// Set attempts to set z's value to val, and reports whether it succeeded. -// Set only succeeds if none of Get/GetErr/Set have been called before. -func (z *GValue[T]) Set(v T) bool { - if z.done { - return false - } - if z.calling { - panic("Set while Get fill is running") - } - z.V = v - z.done = true - return true -} - -// MustSet sets z's value to val, or panics if z already has a value. -func (z *GValue[T]) MustSet(val T) { - if !z.Set(val) { - panic("Set after already filled") - } -} - -// Get returns z's value, calling fill to compute it if necessary. -// f is called at most once. -func (z *GValue[T]) Get(fill func() T) T { - if !z.done { - if z.calling { - panic("recursive lazy fill") - } - z.calling = true - z.V = fill() - z.done = true - z.calling = false - } - return z.V -} - -// GetErr returns z's value, calling fill to compute it if necessary. -// f is called at most once, and z remembers both of fill's outputs. -func (z *GValue[T]) GetErr(fill func() (T, error)) (T, error) { - if !z.done { - if z.calling { - panic("recursive lazy fill") - } - z.calling = true - z.V, z.err = fill() - z.done = true - z.calling = false - } - return z.V, z.err -} - -// GFunc wraps a function to make it lazy. -// -// The returned function calls fill the first time it's called, and returns -// fill's result on every subsequent call. -// -// The returned function is not safe for concurrent use. -func GFunc[T any](fill func() T) func() T { - var v GValue[T] - return func() T { - return v.Get(fill) - } -} - -// SyncFuncErr wraps a function to make it lazy. -// -// The returned function calls fill the first time it's called, and returns -// fill's results on every subsequent call. -// -// The returned function is not safe for concurrent use. -func GFuncErr[T any](fill func() (T, error)) func() (T, error) { - var v GValue[T] - return func() (T, error) { - return v.GetErr(fill) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package lazy + +// GValue is a lazily computed value. +// +// Use either Get or GetErr, depending on whether your fill function returns an +// error. +// +// Recursive use of a GValue from its own fill function will panic. +// +// GValue is not safe for concurrent use. (Mnemonic: G is for one Goroutine, +// which isn't strictly true if you provide your own synchronization between +// goroutines, but in practice most of our callers have been using it within +// a single goroutine.) +type GValue[T any] struct { + done bool + calling bool + V T + err error +} + +// Set attempts to set z's value to val, and reports whether it succeeded. +// Set only succeeds if none of Get/GetErr/Set have been called before. +func (z *GValue[T]) Set(v T) bool { + if z.done { + return false + } + if z.calling { + panic("Set while Get fill is running") + } + z.V = v + z.done = true + return true +} + +// MustSet sets z's value to val, or panics if z already has a value. +func (z *GValue[T]) MustSet(val T) { + if !z.Set(val) { + panic("Set after already filled") + } +} + +// Get returns z's value, calling fill to compute it if necessary. +// f is called at most once. +func (z *GValue[T]) Get(fill func() T) T { + if !z.done { + if z.calling { + panic("recursive lazy fill") + } + z.calling = true + z.V = fill() + z.done = true + z.calling = false + } + return z.V +} + +// GetErr returns z's value, calling fill to compute it if necessary. +// f is called at most once, and z remembers both of fill's outputs. +func (z *GValue[T]) GetErr(fill func() (T, error)) (T, error) { + if !z.done { + if z.calling { + panic("recursive lazy fill") + } + z.calling = true + z.V, z.err = fill() + z.done = true + z.calling = false + } + return z.V, z.err +} + +// GFunc wraps a function to make it lazy. +// +// The returned function calls fill the first time it's called, and returns +// fill's result on every subsequent call. +// +// The returned function is not safe for concurrent use. +func GFunc[T any](fill func() T) func() T { + var v GValue[T] + return func() T { + return v.Get(fill) + } +} + +// SyncFuncErr wraps a function to make it lazy. +// +// The returned function calls fill the first time it's called, and returns +// fill's results on every subsequent call. +// +// The returned function is not safe for concurrent use. +func GFuncErr[T any](fill func() (T, error)) func() (T, error) { + var v GValue[T] + return func() (T, error) { + return v.GetErr(fill) + } +} diff --git a/types/lazy/unsync_test.go b/types/lazy/unsync_test.go index d8b870dbe..f0d2494d1 100644 --- a/types/lazy/unsync_test.go +++ b/types/lazy/unsync_test.go @@ -1,140 +1,140 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package lazy - -import ( - "errors" - "testing" -) - -func fortyTwo() int { return 42 } - -func TestGValue(t *testing.T) { - var lt GValue[int] - n := int(testing.AllocsPerRun(1000, func() { - got := lt.Get(fortyTwo) - if got != 42 { - t.Fatalf("got %v; want 42", got) - } - })) - if n != 0 { - t.Errorf("allocs = %v; want 0", n) - } -} - -func TestGValueErr(t *testing.T) { - var lt GValue[int] - n := int(testing.AllocsPerRun(1000, func() { - got, err := lt.GetErr(func() (int, error) { - return 42, nil - }) - if got != 42 || err != nil { - t.Fatalf("got %v, %v; want 42, nil", got, err) - } - })) - if n != 0 { - t.Errorf("allocs = %v; want 0", n) - } - - var lterr GValue[int] - wantErr := errors.New("test error") - n = int(testing.AllocsPerRun(1000, func() { - got, err := lterr.GetErr(func() (int, error) { - return 0, wantErr - }) - if got != 0 || err != wantErr { - t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr) - } - })) - if n != 0 { - t.Errorf("allocs = %v; want 0", n) - } -} - -func TestGValueSet(t *testing.T) { - var lt GValue[int] - if !lt.Set(42) { - t.Fatalf("Set failed") - } - if lt.Set(43) { - t.Fatalf("Set succeeded after first Set") - } - n := int(testing.AllocsPerRun(1000, func() { - got := lt.Get(fortyTwo) - if got != 42 { - t.Fatalf("got %v; want 42", got) - } - })) - if n != 0 { - t.Errorf("allocs = %v; want 0", n) - } -} - -func TestGValueMustSet(t *testing.T) { - var lt GValue[int] - lt.MustSet(42) - defer func() { - if e := recover(); e == nil { - t.Errorf("unexpected success; want panic") - } - }() - lt.MustSet(43) -} - -func TestGValueRecursivePanic(t *testing.T) { - defer func() { - if e := recover(); e != nil { - t.Logf("got panic, as expected") - } else { - t.Errorf("unexpected success; want panic") - } - }() - v := GValue[int]{} - v.Get(func() int { - return v.Get(func() int { return 42 }) - }) -} - -func TestGFunc(t *testing.T) { - f := GFunc(fortyTwo) - - n := int(testing.AllocsPerRun(1000, func() { - got := f() - if got != 42 { - t.Fatalf("got %v; want 42", got) - } - })) - if n != 0 { - t.Errorf("allocs = %v; want 0", n) - } -} - -func TestGFuncErr(t *testing.T) { - f := GFuncErr(func() (int, error) { - return 42, nil - }) - n := int(testing.AllocsPerRun(1000, func() { - got, err := f() - if got != 42 || err != nil { - t.Fatalf("got %v, %v; want 42, nil", got, err) - } - })) - if n != 0 { - t.Errorf("allocs = %v; want 0", n) - } - - wantErr := errors.New("test error") - f = GFuncErr(func() (int, error) { - return 0, wantErr - }) - n = int(testing.AllocsPerRun(1000, func() { - got, err := f() - if got != 0 || err != wantErr { - t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr) - } - })) - if n != 0 { - t.Errorf("allocs = %v; want 0", n) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package lazy + +import ( + "errors" + "testing" +) + +func fortyTwo() int { return 42 } + +func TestGValue(t *testing.T) { + var lt GValue[int] + n := int(testing.AllocsPerRun(1000, func() { + got := lt.Get(fortyTwo) + if got != 42 { + t.Fatalf("got %v; want 42", got) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } +} + +func TestGValueErr(t *testing.T) { + var lt GValue[int] + n := int(testing.AllocsPerRun(1000, func() { + got, err := lt.GetErr(func() (int, error) { + return 42, nil + }) + if got != 42 || err != nil { + t.Fatalf("got %v, %v; want 42, nil", got, err) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } + + var lterr GValue[int] + wantErr := errors.New("test error") + n = int(testing.AllocsPerRun(1000, func() { + got, err := lterr.GetErr(func() (int, error) { + return 0, wantErr + }) + if got != 0 || err != wantErr { + t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } +} + +func TestGValueSet(t *testing.T) { + var lt GValue[int] + if !lt.Set(42) { + t.Fatalf("Set failed") + } + if lt.Set(43) { + t.Fatalf("Set succeeded after first Set") + } + n := int(testing.AllocsPerRun(1000, func() { + got := lt.Get(fortyTwo) + if got != 42 { + t.Fatalf("got %v; want 42", got) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } +} + +func TestGValueMustSet(t *testing.T) { + var lt GValue[int] + lt.MustSet(42) + defer func() { + if e := recover(); e == nil { + t.Errorf("unexpected success; want panic") + } + }() + lt.MustSet(43) +} + +func TestGValueRecursivePanic(t *testing.T) { + defer func() { + if e := recover(); e != nil { + t.Logf("got panic, as expected") + } else { + t.Errorf("unexpected success; want panic") + } + }() + v := GValue[int]{} + v.Get(func() int { + return v.Get(func() int { return 42 }) + }) +} + +func TestGFunc(t *testing.T) { + f := GFunc(fortyTwo) + + n := int(testing.AllocsPerRun(1000, func() { + got := f() + if got != 42 { + t.Fatalf("got %v; want 42", got) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } +} + +func TestGFuncErr(t *testing.T) { + f := GFuncErr(func() (int, error) { + return 42, nil + }) + n := int(testing.AllocsPerRun(1000, func() { + got, err := f() + if got != 42 || err != nil { + t.Fatalf("got %v, %v; want 42, nil", got, err) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } + + wantErr := errors.New("test error") + f = GFuncErr(func() (int, error) { + return 0, wantErr + }) + n = int(testing.AllocsPerRun(1000, func() { + got, err := f() + if got != 0 || err != wantErr { + t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } +} diff --git a/types/logger/rusage.go b/types/logger/rusage.go index ebe0e972d..3943636d6 100644 --- a/types/logger/rusage.go +++ b/types/logger/rusage.go @@ -1,23 +1,23 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package logger - -import ( - "fmt" - "runtime" -) - -// RusagePrefixLog returns a Logf func wrapping the provided logf func that adds -// a prefixed log message to each line with the current binary memory usage -// and max RSS. -func RusagePrefixLog(logf Logf) Logf { - return func(f string, argv ...any) { - var m runtime.MemStats - runtime.ReadMemStats(&m) - goMem := float64(m.HeapInuse+m.StackInuse) / (1 << 20) - maxRSS := rusageMaxRSS() - pf := fmt.Sprintf("%.1fM/%.1fM %s", goMem, maxRSS, f) - logf(pf, argv...) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package logger + +import ( + "fmt" + "runtime" +) + +// RusagePrefixLog returns a Logf func wrapping the provided logf func that adds +// a prefixed log message to each line with the current binary memory usage +// and max RSS. +func RusagePrefixLog(logf Logf) Logf { + return func(f string, argv ...any) { + var m runtime.MemStats + runtime.ReadMemStats(&m) + goMem := float64(m.HeapInuse+m.StackInuse) / (1 << 20) + maxRSS := rusageMaxRSS() + pf := fmt.Sprintf("%.1fM/%.1fM %s", goMem, maxRSS, f) + logf(pf, argv...) + } +} diff --git a/types/logger/rusage_stub.go b/types/logger/rusage_stub.go index a228b0865..f646f1e1e 100644 --- a/types/logger/rusage_stub.go +++ b/types/logger/rusage_stub.go @@ -1,11 +1,11 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build windows || wasm || plan9 || tamago - -package logger - -func rusageMaxRSS() float64 { - // TODO(apenwarr): Substitute Windows equivalent of Getrusage() here. - return 0 -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows || wasm || plan9 || tamago + +package logger + +func rusageMaxRSS() float64 { + // TODO(apenwarr): Substitute Windows equivalent of Getrusage() here. + return 0 +} diff --git a/types/logger/rusage_syscall.go b/types/logger/rusage_syscall.go index 19488aef1..2871b66c6 100644 --- a/types/logger/rusage_syscall.go +++ b/types/logger/rusage_syscall.go @@ -1,29 +1,29 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows && !wasm && !plan9 && !tamago - -package logger - -import ( - "runtime" - - "golang.org/x/sys/unix" -) - -func rusageMaxRSS() float64 { - var ru unix.Rusage - err := unix.Getrusage(unix.RUSAGE_SELF, &ru) - if err != nil { - return 0 - } - - rss := float64(ru.Maxrss) - if runtime.GOOS == "darwin" || runtime.GOOS == "ios" { - rss /= 1 << 20 // ru_maxrss is bytes on darwin - } else { - // ru_maxrss is kilobytes elsewhere (linux, openbsd, etc) - rss /= 1 << 10 - } - return rss -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows && !wasm && !plan9 && !tamago + +package logger + +import ( + "runtime" + + "golang.org/x/sys/unix" +) + +func rusageMaxRSS() float64 { + var ru unix.Rusage + err := unix.Getrusage(unix.RUSAGE_SELF, &ru) + if err != nil { + return 0 + } + + rss := float64(ru.Maxrss) + if runtime.GOOS == "darwin" || runtime.GOOS == "ios" { + rss /= 1 << 20 // ru_maxrss is bytes on darwin + } else { + // ru_maxrss is kilobytes elsewhere (linux, openbsd, etc) + rss /= 1 << 10 + } + return rss +} diff --git a/types/logger/tokenbucket.go b/types/logger/tokenbucket.go index 2407e01a7..83d4059c2 100644 --- a/types/logger/tokenbucket.go +++ b/types/logger/tokenbucket.go @@ -1,63 +1,63 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package logger - -import ( - "time" -) - -// tokenBucket is a simple token bucket style rate limiter. - -// It's similar in function to golang.org/x/time/rate.Limiter, which we -// can't use because: -// - It doesn't give access to the number of accumulated tokens, which we -// need for implementing hysteresis; -// - It doesn't let us provide our own time function, which we need for -// implementing proper unit tests. -// -// rate.Limiter is also much more complex than necessary, but that wouldn't -// be enough to disqualify it on its own. -// -// Unlike rate.Limiter, this token bucket does not attempt to -// do any locking of its own. Don't try to access it reentrantly. -// That's fine inside this types/logger package because we already have -// locking at a higher level. -type tokenBucket struct { - remaining int - max int - tick time.Duration - t time.Time -} - -func newTokenBucket(tick time.Duration, max int, now time.Time) *tokenBucket { - return &tokenBucket{max, max, tick, now} -} - -func (tb *tokenBucket) Get() bool { - if tb.remaining > 0 { - tb.remaining-- - return true - } - return false -} - -func (tb *tokenBucket) Refund(n int) { - b := tb.remaining + n - if b > tb.max { - tb.remaining = tb.max - } else { - tb.remaining = b - } -} - -func (tb *tokenBucket) AdvanceTo(t time.Time) { - diff := t.Sub(tb.t) - - // only use up whole ticks. The remainder will be used up - // next time. - ticks := int(diff / tb.tick) - tb.t = tb.t.Add(time.Duration(ticks) * tb.tick) - - tb.Refund(ticks) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package logger + +import ( + "time" +) + +// tokenBucket is a simple token bucket style rate limiter. + +// It's similar in function to golang.org/x/time/rate.Limiter, which we +// can't use because: +// - It doesn't give access to the number of accumulated tokens, which we +// need for implementing hysteresis; +// - It doesn't let us provide our own time function, which we need for +// implementing proper unit tests. +// +// rate.Limiter is also much more complex than necessary, but that wouldn't +// be enough to disqualify it on its own. +// +// Unlike rate.Limiter, this token bucket does not attempt to +// do any locking of its own. Don't try to access it reentrantly. +// That's fine inside this types/logger package because we already have +// locking at a higher level. +type tokenBucket struct { + remaining int + max int + tick time.Duration + t time.Time +} + +func newTokenBucket(tick time.Duration, max int, now time.Time) *tokenBucket { + return &tokenBucket{max, max, tick, now} +} + +func (tb *tokenBucket) Get() bool { + if tb.remaining > 0 { + tb.remaining-- + return true + } + return false +} + +func (tb *tokenBucket) Refund(n int) { + b := tb.remaining + n + if b > tb.max { + tb.remaining = tb.max + } else { + tb.remaining = b + } +} + +func (tb *tokenBucket) AdvanceTo(t time.Time) { + diff := t.Sub(tb.t) + + // only use up whole ticks. The remainder will be used up + // next time. + ticks := int(diff / tb.tick) + tb.t = tb.t.Add(time.Duration(ticks) * tb.tick) + + tb.Refund(ticks) +} diff --git a/types/netlogtype/netlogtype.go b/types/netlogtype/netlogtype.go index 56002628e..f2fa2bda9 100644 --- a/types/netlogtype/netlogtype.go +++ b/types/netlogtype/netlogtype.go @@ -1,100 +1,100 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package netlogtype defines types for network logging. -package netlogtype - -import ( - "net/netip" - "time" - - "tailscale.com/tailcfg" - "tailscale.com/types/ipproto" -) - -// TODO(joetsai): Remove "omitempty" if "omitzero" is ever supported in both -// the v1 and v2 "json" packages. - -// Message is the log message that captures network traffic. -type Message struct { - NodeID tailcfg.StableNodeID `json:"nodeId" cbor:"0,keyasint"` // e.g., "n123456CNTRL" - - Start time.Time `json:"start" cbor:"12,keyasint"` // inclusive - End time.Time `json:"end" cbor:"13,keyasint"` // inclusive - - VirtualTraffic []ConnectionCounts `json:"virtualTraffic,omitempty" cbor:"14,keyasint,omitempty"` - SubnetTraffic []ConnectionCounts `json:"subnetTraffic,omitempty" cbor:"15,keyasint,omitempty"` - ExitTraffic []ConnectionCounts `json:"exitTraffic,omitempty" cbor:"16,keyasint,omitempty"` - PhysicalTraffic []ConnectionCounts `json:"physicalTraffic,omitempty" cbor:"17,keyasint,omitempty"` -} - -const ( - messageJSON = `{"nodeId":"n0123456789abcdefCNTRL",` + maxJSONTimeRange + `,` + minJSONTraffic + `}` - maxJSONTimeRange = `"start":` + maxJSONRFC3339 + `,"end":` + maxJSONRFC3339 - maxJSONRFC3339 = `"0001-01-01T00:00:00.000000000Z"` - minJSONTraffic = `"virtualTraffic":{},"subnetTraffic":{},"exitTraffic":{},"physicalTraffic":{}` - - // MaxMessageJSONSize is the overhead size of Message when it is - // serialized as JSON assuming that each traffic map is populated. - MaxMessageJSONSize = len(messageJSON) - - maxJSONConnCounts = `{` + maxJSONConn + `,` + maxJSONCounts + `}` - maxJSONConn = `"proto":` + maxJSONProto + `,"src":` + maxJSONAddrPort + `,"dst":` + maxJSONAddrPort - maxJSONProto = `255` - maxJSONAddrPort = `"[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff]:65535"` - maxJSONCounts = `"txPkts":` + maxJSONCount + `,"txBytes":` + maxJSONCount + `,"rxPkts":` + maxJSONCount + `,"rxBytes":` + maxJSONCount - maxJSONCount = `18446744073709551615` - - // MaxConnectionCountsJSONSize is the maximum size of a ConnectionCounts - // when it is serialized as JSON, assuming no superfluous whitespace. - // It does not include the trailing comma that often appears when - // this object is nested within an array. - // It assumes that netip.Addr never has IPv6 zones. - MaxConnectionCountsJSONSize = len(maxJSONConnCounts) - - maxCBORConnCounts = "\xbf" + maxCBORConn + maxCBORCounts + "\xff" - maxCBORConn = "\x00" + maxCBORProto + "\x01" + maxCBORAddrPort + "\x02" + maxCBORAddrPort - maxCBORProto = "\x18\xff" - maxCBORAddrPort = "\x52\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" - maxCBORCounts = "\x0c" + maxCBORCount + "\x0d" + maxCBORCount + "\x0e" + maxCBORCount + "\x0f" + maxCBORCount - maxCBORCount = "\x1b\xff\xff\xff\xff\xff\xff\xff\xff" - - // MaxConnectionCountsCBORSize is the maximum size of a ConnectionCounts - // when it is serialized as CBOR. - // It assumes that netip.Addr never has IPv6 zones. - MaxConnectionCountsCBORSize = len(maxCBORConnCounts) -) - -// ConnectionCounts is a flattened struct of both a connection and counts. -type ConnectionCounts struct { - Connection - Counts -} - -// Connection is a 5-tuple of proto, source and destination IP and port. -type Connection struct { - Proto ipproto.Proto `json:"proto,omitzero,omitempty" cbor:"0,keyasint,omitempty"` - Src netip.AddrPort `json:"src,omitzero,omitempty" cbor:"1,keyasint,omitempty"` - Dst netip.AddrPort `json:"dst,omitzero,omitempty" cbor:"2,keyasint,omitempty"` -} - -func (c Connection) IsZero() bool { return c == Connection{} } - -// Counts are statistics about a particular connection. -type Counts struct { - TxPackets uint64 `json:"txPkts,omitzero,omitempty" cbor:"12,keyasint,omitempty"` - TxBytes uint64 `json:"txBytes,omitzero,omitempty" cbor:"13,keyasint,omitempty"` - RxPackets uint64 `json:"rxPkts,omitzero,omitempty" cbor:"14,keyasint,omitempty"` - RxBytes uint64 `json:"rxBytes,omitzero,omitempty" cbor:"15,keyasint,omitempty"` -} - -func (c Counts) IsZero() bool { return c == Counts{} } - -// Add adds the counts from both c1 and c2. -func (c1 Counts) Add(c2 Counts) Counts { - c1.TxPackets += c2.TxPackets - c1.TxBytes += c2.TxBytes - c1.RxPackets += c2.RxPackets - c1.RxBytes += c2.RxBytes - return c1 -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package netlogtype defines types for network logging. +package netlogtype + +import ( + "net/netip" + "time" + + "tailscale.com/tailcfg" + "tailscale.com/types/ipproto" +) + +// TODO(joetsai): Remove "omitempty" if "omitzero" is ever supported in both +// the v1 and v2 "json" packages. + +// Message is the log message that captures network traffic. +type Message struct { + NodeID tailcfg.StableNodeID `json:"nodeId" cbor:"0,keyasint"` // e.g., "n123456CNTRL" + + Start time.Time `json:"start" cbor:"12,keyasint"` // inclusive + End time.Time `json:"end" cbor:"13,keyasint"` // inclusive + + VirtualTraffic []ConnectionCounts `json:"virtualTraffic,omitempty" cbor:"14,keyasint,omitempty"` + SubnetTraffic []ConnectionCounts `json:"subnetTraffic,omitempty" cbor:"15,keyasint,omitempty"` + ExitTraffic []ConnectionCounts `json:"exitTraffic,omitempty" cbor:"16,keyasint,omitempty"` + PhysicalTraffic []ConnectionCounts `json:"physicalTraffic,omitempty" cbor:"17,keyasint,omitempty"` +} + +const ( + messageJSON = `{"nodeId":"n0123456789abcdefCNTRL",` + maxJSONTimeRange + `,` + minJSONTraffic + `}` + maxJSONTimeRange = `"start":` + maxJSONRFC3339 + `,"end":` + maxJSONRFC3339 + maxJSONRFC3339 = `"0001-01-01T00:00:00.000000000Z"` + minJSONTraffic = `"virtualTraffic":{},"subnetTraffic":{},"exitTraffic":{},"physicalTraffic":{}` + + // MaxMessageJSONSize is the overhead size of Message when it is + // serialized as JSON assuming that each traffic map is populated. + MaxMessageJSONSize = len(messageJSON) + + maxJSONConnCounts = `{` + maxJSONConn + `,` + maxJSONCounts + `}` + maxJSONConn = `"proto":` + maxJSONProto + `,"src":` + maxJSONAddrPort + `,"dst":` + maxJSONAddrPort + maxJSONProto = `255` + maxJSONAddrPort = `"[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff]:65535"` + maxJSONCounts = `"txPkts":` + maxJSONCount + `,"txBytes":` + maxJSONCount + `,"rxPkts":` + maxJSONCount + `,"rxBytes":` + maxJSONCount + maxJSONCount = `18446744073709551615` + + // MaxConnectionCountsJSONSize is the maximum size of a ConnectionCounts + // when it is serialized as JSON, assuming no superfluous whitespace. + // It does not include the trailing comma that often appears when + // this object is nested within an array. + // It assumes that netip.Addr never has IPv6 zones. + MaxConnectionCountsJSONSize = len(maxJSONConnCounts) + + maxCBORConnCounts = "\xbf" + maxCBORConn + maxCBORCounts + "\xff" + maxCBORConn = "\x00" + maxCBORProto + "\x01" + maxCBORAddrPort + "\x02" + maxCBORAddrPort + maxCBORProto = "\x18\xff" + maxCBORAddrPort = "\x52\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" + maxCBORCounts = "\x0c" + maxCBORCount + "\x0d" + maxCBORCount + "\x0e" + maxCBORCount + "\x0f" + maxCBORCount + maxCBORCount = "\x1b\xff\xff\xff\xff\xff\xff\xff\xff" + + // MaxConnectionCountsCBORSize is the maximum size of a ConnectionCounts + // when it is serialized as CBOR. + // It assumes that netip.Addr never has IPv6 zones. + MaxConnectionCountsCBORSize = len(maxCBORConnCounts) +) + +// ConnectionCounts is a flattened struct of both a connection and counts. +type ConnectionCounts struct { + Connection + Counts +} + +// Connection is a 5-tuple of proto, source and destination IP and port. +type Connection struct { + Proto ipproto.Proto `json:"proto,omitzero,omitempty" cbor:"0,keyasint,omitempty"` + Src netip.AddrPort `json:"src,omitzero,omitempty" cbor:"1,keyasint,omitempty"` + Dst netip.AddrPort `json:"dst,omitzero,omitempty" cbor:"2,keyasint,omitempty"` +} + +func (c Connection) IsZero() bool { return c == Connection{} } + +// Counts are statistics about a particular connection. +type Counts struct { + TxPackets uint64 `json:"txPkts,omitzero,omitempty" cbor:"12,keyasint,omitempty"` + TxBytes uint64 `json:"txBytes,omitzero,omitempty" cbor:"13,keyasint,omitempty"` + RxPackets uint64 `json:"rxPkts,omitzero,omitempty" cbor:"14,keyasint,omitempty"` + RxBytes uint64 `json:"rxBytes,omitzero,omitempty" cbor:"15,keyasint,omitempty"` +} + +func (c Counts) IsZero() bool { return c == Counts{} } + +// Add adds the counts from both c1 and c2. +func (c1 Counts) Add(c2 Counts) Counts { + c1.TxPackets += c2.TxPackets + c1.TxBytes += c2.TxBytes + c1.RxPackets += c2.RxPackets + c1.RxBytes += c2.RxBytes + return c1 +} diff --git a/types/netlogtype/netlogtype_test.go b/types/netlogtype/netlogtype_test.go index 1fa604b31..7f29090c5 100644 --- a/types/netlogtype/netlogtype_test.go +++ b/types/netlogtype/netlogtype_test.go @@ -1,39 +1,39 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package netlogtype - -import ( - "encoding/json" - "math" - "net/netip" - "testing" - - "github.com/fxamacker/cbor/v2" - "github.com/google/go-cmp/cmp" - "tailscale.com/util/must" -) - -func TestMaxSize(t *testing.T) { - maxAddr := netip.AddrFrom16([16]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}) - maxAddrPort := netip.AddrPortFrom(maxAddr, math.MaxUint16) - cc := ConnectionCounts{ - // NOTE: These composite literals are deliberately unkeyed so that - // added fields result in a build failure here. - // Newly added fields should result in an update to both - // MaxConnectionCountsJSONSize and MaxConnectionCountsCBORSize. - Connection{math.MaxUint8, maxAddrPort, maxAddrPort}, - Counts{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64}, - } - - outJSON := must.Get(json.Marshal(cc)) - if string(outJSON) != maxJSONConnCounts { - t.Errorf("JSON mismatch (-got +want):\n%s", cmp.Diff(string(outJSON), maxJSONConnCounts)) - } - - outCBOR := must.Get(cbor.Marshal(cc)) - maxCBORConnCountsAlt := "\xa7" + maxCBORConnCounts[1:len(maxCBORConnCounts)-1] // may use a definite encoding of map - if string(outCBOR) != maxCBORConnCounts && string(outCBOR) != maxCBORConnCountsAlt { - t.Errorf("CBOR mismatch (-got +want):\n%s", cmp.Diff(string(outCBOR), maxCBORConnCounts)) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package netlogtype + +import ( + "encoding/json" + "math" + "net/netip" + "testing" + + "github.com/fxamacker/cbor/v2" + "github.com/google/go-cmp/cmp" + "tailscale.com/util/must" +) + +func TestMaxSize(t *testing.T) { + maxAddr := netip.AddrFrom16([16]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}) + maxAddrPort := netip.AddrPortFrom(maxAddr, math.MaxUint16) + cc := ConnectionCounts{ + // NOTE: These composite literals are deliberately unkeyed so that + // added fields result in a build failure here. + // Newly added fields should result in an update to both + // MaxConnectionCountsJSONSize and MaxConnectionCountsCBORSize. + Connection{math.MaxUint8, maxAddrPort, maxAddrPort}, + Counts{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64}, + } + + outJSON := must.Get(json.Marshal(cc)) + if string(outJSON) != maxJSONConnCounts { + t.Errorf("JSON mismatch (-got +want):\n%s", cmp.Diff(string(outJSON), maxJSONConnCounts)) + } + + outCBOR := must.Get(cbor.Marshal(cc)) + maxCBORConnCountsAlt := "\xa7" + maxCBORConnCounts[1:len(maxCBORConnCounts)-1] // may use a definite encoding of map + if string(outCBOR) != maxCBORConnCounts && string(outCBOR) != maxCBORConnCountsAlt { + t.Errorf("CBOR mismatch (-got +want):\n%s", cmp.Diff(string(outCBOR), maxCBORConnCounts)) + } +} diff --git a/types/netmap/netmap_test.go b/types/netmap/netmap_test.go index 910b6bc21..e7e2d1957 100644 --- a/types/netmap/netmap_test.go +++ b/types/netmap/netmap_test.go @@ -1,318 +1,318 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package netmap - -import ( - "encoding/hex" - "net/netip" - "testing" - - "go4.org/mem" - "tailscale.com/net/netaddr" - "tailscale.com/tailcfg" - "tailscale.com/types/key" -) - -func testNodeKey(b byte) (ret key.NodePublic) { - var bs [key.NodePublicRawLen]byte - for i := range bs { - bs[i] = b - } - return key.NodePublicFromRaw32(mem.B(bs[:])) -} - -func testDiscoKey(hexPrefix string) (ret key.DiscoPublic) { - b, err := hex.DecodeString(hexPrefix) - if err != nil { - panic(err) - } - // this function is used with short hexes, so zero-extend the raw - // value. - var bs [32]byte - copy(bs[:], b) - return key.DiscoPublicFromRaw32(mem.B(bs[:])) -} - -func nodeViews(v []*tailcfg.Node) []tailcfg.NodeView { - nv := make([]tailcfg.NodeView, len(v)) - for i, n := range v { - nv[i] = n.View() - } - return nv -} - -func eps(s ...string) []netip.AddrPort { - var eps []netip.AddrPort - for _, ep := range s { - eps = append(eps, netip.MustParseAddrPort(ep)) - } - return eps -} - -func TestNetworkMapConcise(t *testing.T) { - for _, tt := range []struct { - name string - nm *NetworkMap - want string - }{ - { - name: "basic", - nm: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - { - Key: testNodeKey(3), - DERP: "127.3.3.40:4", - Endpoints: eps("10.2.0.100:12", "10.1.0.100:12345"), - }, - }), - }, - want: "netmap: self: [AQEBA] auth=machine-unknown u=? []\n [AgICA] D2 : 192.168.0.100:12 192.168.0.100:12354\n [AwMDA] D4 : 10.2.0.100:12 10.1.0.100:12345\n", - }, - } { - t.Run(tt.name, func(t *testing.T) { - var got string - n := int(testing.AllocsPerRun(1000, func() { - got = tt.nm.Concise() - })) - t.Logf("Allocs = %d", n) - if got != tt.want { - t.Errorf("Wrong output\n Got: %q\nWant: %q\n## Got (unescaped):\n%s\n## Want (unescaped):\n%s\n", got, tt.want, got, tt.want) - } - }) - } -} - -func TestConciseDiffFrom(t *testing.T) { - for _, tt := range []struct { - name string - a, b *NetworkMap - want string - }{ - { - name: "no_change", - a: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - }), - }, - b: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - }), - }, - want: "", - }, - { - name: "header_change", - a: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - }), - }, - b: &NetworkMap{ - NodeKey: testNodeKey(2), - Peers: nodeViews([]*tailcfg.Node{ - { - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - }), - }, - want: "-netmap: self: [AQEBA] auth=machine-unknown u=? []\n+netmap: self: [AgICA] auth=machine-unknown u=? []\n", - }, - { - name: "peer_add", - a: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - ID: 2, - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - }), - }, - b: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - ID: 1, - Key: testNodeKey(1), - DERP: "127.3.3.40:1", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - { - ID: 2, - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - { - ID: 3, - Key: testNodeKey(3), - DERP: "127.3.3.40:3", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - }), - }, - want: "+ [AQEBA] D1 : 192.168.0.100:12 192.168.0.100:12354\n+ [AwMDA] D3 : 192.168.0.100:12 192.168.0.100:12354\n", - }, - { - name: "peer_remove", - a: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - ID: 1, - Key: testNodeKey(1), - DERP: "127.3.3.40:1", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - { - ID: 2, - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - { - ID: 3, - Key: testNodeKey(3), - DERP: "127.3.3.40:3", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - }), - }, - b: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - ID: 2, - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), - }, - }), - }, - want: "- [AQEBA] D1 : 192.168.0.100:12 192.168.0.100:12354\n- [AwMDA] D3 : 192.168.0.100:12 192.168.0.100:12354\n", - }, - { - name: "peer_port_change", - a: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - ID: 2, - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "1.1.1.1:1"), - }, - }), - }, - b: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - ID: 2, - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:12", "1.1.1.1:2"), - }, - }), - }, - want: "- [AgICA] D2 : 192.168.0.100:12 1.1.1.1:1 \n+ [AgICA] D2 : 192.168.0.100:12 1.1.1.1:2 \n", - }, - { - name: "disco_key_only_change", - a: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - ID: 2, - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:41641", "1.1.1.1:41641"), - DiscoKey: testDiscoKey("f00f00f00f"), - AllowedIPs: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 102, 103, 104), 32)}, - }, - }), - }, - b: &NetworkMap{ - NodeKey: testNodeKey(1), - Peers: nodeViews([]*tailcfg.Node{ - { - ID: 2, - Key: testNodeKey(2), - DERP: "127.3.3.40:2", - Endpoints: eps("192.168.0.100:41641", "1.1.1.1:41641"), - DiscoKey: testDiscoKey("ba4ba4ba4b"), - AllowedIPs: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 102, 103, 104), 32)}, - }, - }), - }, - want: "- [AgICA] d:f00f00f00f000000 D2 100.102.103.104 : 192.168.0.100:41641 1.1.1.1:41641\n+ [AgICA] d:ba4ba4ba4b000000 D2 100.102.103.104 : 192.168.0.100:41641 1.1.1.1:41641\n", - }, - } { - t.Run(tt.name, func(t *testing.T) { - var got string - n := int(testing.AllocsPerRun(50, func() { - got = tt.b.ConciseDiffFrom(tt.a) - })) - t.Logf("Allocs = %d", n) - if got != tt.want { - t.Errorf("Wrong output\n Got: %q\nWant: %q\n## Got (unescaped):\n%s\n## Want (unescaped):\n%s\n", got, tt.want, got, tt.want) - } - }) - } -} - -func TestPeerIndexByNodeID(t *testing.T) { - var nilPtr *NetworkMap - if nilPtr.PeerIndexByNodeID(123) != -1 { - t.Errorf("nil PeerIndexByNodeID should return -1") - } - var nm NetworkMap - const min = 2 - const max = 10000 - const hole = max / 2 - for nid := tailcfg.NodeID(2); nid <= max; nid++ { - if nid == hole { - continue - } - nm.Peers = append(nm.Peers, (&tailcfg.Node{ID: nid}).View()) - } - for want, nv := range nm.Peers { - got := nm.PeerIndexByNodeID(nv.ID()) - if got != want { - t.Errorf("PeerIndexByNodeID(%v) = %v; want %v", nv.ID(), got, want) - } - } - for _, miss := range []tailcfg.NodeID{min - 1, hole, max + 1} { - if got := nm.PeerIndexByNodeID(miss); got != -1 { - t.Errorf("PeerIndexByNodeID(%v) = %v; want -1", miss, got) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package netmap + +import ( + "encoding/hex" + "net/netip" + "testing" + + "go4.org/mem" + "tailscale.com/net/netaddr" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +func testNodeKey(b byte) (ret key.NodePublic) { + var bs [key.NodePublicRawLen]byte + for i := range bs { + bs[i] = b + } + return key.NodePublicFromRaw32(mem.B(bs[:])) +} + +func testDiscoKey(hexPrefix string) (ret key.DiscoPublic) { + b, err := hex.DecodeString(hexPrefix) + if err != nil { + panic(err) + } + // this function is used with short hexes, so zero-extend the raw + // value. + var bs [32]byte + copy(bs[:], b) + return key.DiscoPublicFromRaw32(mem.B(bs[:])) +} + +func nodeViews(v []*tailcfg.Node) []tailcfg.NodeView { + nv := make([]tailcfg.NodeView, len(v)) + for i, n := range v { + nv[i] = n.View() + } + return nv +} + +func eps(s ...string) []netip.AddrPort { + var eps []netip.AddrPort + for _, ep := range s { + eps = append(eps, netip.MustParseAddrPort(ep)) + } + return eps +} + +func TestNetworkMapConcise(t *testing.T) { + for _, tt := range []struct { + name string + nm *NetworkMap + want string + }{ + { + name: "basic", + nm: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + { + Key: testNodeKey(3), + DERP: "127.3.3.40:4", + Endpoints: eps("10.2.0.100:12", "10.1.0.100:12345"), + }, + }), + }, + want: "netmap: self: [AQEBA] auth=machine-unknown u=? []\n [AgICA] D2 : 192.168.0.100:12 192.168.0.100:12354\n [AwMDA] D4 : 10.2.0.100:12 10.1.0.100:12345\n", + }, + } { + t.Run(tt.name, func(t *testing.T) { + var got string + n := int(testing.AllocsPerRun(1000, func() { + got = tt.nm.Concise() + })) + t.Logf("Allocs = %d", n) + if got != tt.want { + t.Errorf("Wrong output\n Got: %q\nWant: %q\n## Got (unescaped):\n%s\n## Want (unescaped):\n%s\n", got, tt.want, got, tt.want) + } + }) + } +} + +func TestConciseDiffFrom(t *testing.T) { + for _, tt := range []struct { + name string + a, b *NetworkMap + want string + }{ + { + name: "no_change", + a: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + }), + }, + b: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + }), + }, + want: "", + }, + { + name: "header_change", + a: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + }), + }, + b: &NetworkMap{ + NodeKey: testNodeKey(2), + Peers: nodeViews([]*tailcfg.Node{ + { + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + }), + }, + want: "-netmap: self: [AQEBA] auth=machine-unknown u=? []\n+netmap: self: [AgICA] auth=machine-unknown u=? []\n", + }, + { + name: "peer_add", + a: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + ID: 2, + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + }), + }, + b: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + ID: 1, + Key: testNodeKey(1), + DERP: "127.3.3.40:1", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + { + ID: 2, + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + { + ID: 3, + Key: testNodeKey(3), + DERP: "127.3.3.40:3", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + }), + }, + want: "+ [AQEBA] D1 : 192.168.0.100:12 192.168.0.100:12354\n+ [AwMDA] D3 : 192.168.0.100:12 192.168.0.100:12354\n", + }, + { + name: "peer_remove", + a: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + ID: 1, + Key: testNodeKey(1), + DERP: "127.3.3.40:1", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + { + ID: 2, + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + { + ID: 3, + Key: testNodeKey(3), + DERP: "127.3.3.40:3", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + }), + }, + b: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + ID: 2, + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), + }, + }), + }, + want: "- [AQEBA] D1 : 192.168.0.100:12 192.168.0.100:12354\n- [AwMDA] D3 : 192.168.0.100:12 192.168.0.100:12354\n", + }, + { + name: "peer_port_change", + a: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + ID: 2, + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "1.1.1.1:1"), + }, + }), + }, + b: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + ID: 2, + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:12", "1.1.1.1:2"), + }, + }), + }, + want: "- [AgICA] D2 : 192.168.0.100:12 1.1.1.1:1 \n+ [AgICA] D2 : 192.168.0.100:12 1.1.1.1:2 \n", + }, + { + name: "disco_key_only_change", + a: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + ID: 2, + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:41641", "1.1.1.1:41641"), + DiscoKey: testDiscoKey("f00f00f00f"), + AllowedIPs: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 102, 103, 104), 32)}, + }, + }), + }, + b: &NetworkMap{ + NodeKey: testNodeKey(1), + Peers: nodeViews([]*tailcfg.Node{ + { + ID: 2, + Key: testNodeKey(2), + DERP: "127.3.3.40:2", + Endpoints: eps("192.168.0.100:41641", "1.1.1.1:41641"), + DiscoKey: testDiscoKey("ba4ba4ba4b"), + AllowedIPs: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 102, 103, 104), 32)}, + }, + }), + }, + want: "- [AgICA] d:f00f00f00f000000 D2 100.102.103.104 : 192.168.0.100:41641 1.1.1.1:41641\n+ [AgICA] d:ba4ba4ba4b000000 D2 100.102.103.104 : 192.168.0.100:41641 1.1.1.1:41641\n", + }, + } { + t.Run(tt.name, func(t *testing.T) { + var got string + n := int(testing.AllocsPerRun(50, func() { + got = tt.b.ConciseDiffFrom(tt.a) + })) + t.Logf("Allocs = %d", n) + if got != tt.want { + t.Errorf("Wrong output\n Got: %q\nWant: %q\n## Got (unescaped):\n%s\n## Want (unescaped):\n%s\n", got, tt.want, got, tt.want) + } + }) + } +} + +func TestPeerIndexByNodeID(t *testing.T) { + var nilPtr *NetworkMap + if nilPtr.PeerIndexByNodeID(123) != -1 { + t.Errorf("nil PeerIndexByNodeID should return -1") + } + var nm NetworkMap + const min = 2 + const max = 10000 + const hole = max / 2 + for nid := tailcfg.NodeID(2); nid <= max; nid++ { + if nid == hole { + continue + } + nm.Peers = append(nm.Peers, (&tailcfg.Node{ID: nid}).View()) + } + for want, nv := range nm.Peers { + got := nm.PeerIndexByNodeID(nv.ID()) + if got != want { + t.Errorf("PeerIndexByNodeID(%v) = %v; want %v", nv.ID(), got, want) + } + } + for _, miss := range []tailcfg.NodeID{min - 1, hole, max + 1} { + if got := nm.PeerIndexByNodeID(miss); got != -1 { + t.Errorf("PeerIndexByNodeID(%v) = %v; want -1", miss, got) + } + } +} diff --git a/types/nettype/nettype.go b/types/nettype/nettype.go index 8930c36d8..5d3d303c3 100644 --- a/types/nettype/nettype.go +++ b/types/nettype/nettype.go @@ -1,65 +1,65 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package nettype defines an interface that doesn't exist in the Go net package. -package nettype - -import ( - "context" - "io" - "net" - "net/netip" - "time" -) - -// PacketListener defines the ListenPacket method as implemented -// by net.ListenConfig, net.ListenPacket, and tstest/natlab. -type PacketListener interface { - ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) -} - -type PacketListenerWithNetIP interface { - ListenPacket(ctx context.Context, network, address string) (PacketConn, error) -} - -// Std implements PacketListener using the Go net package's ListenPacket func. -type Std struct{} - -func (Std) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { - var conf net.ListenConfig - return conf.ListenPacket(ctx, network, address) -} - -// PacketConn is like a net.PacketConn but uses the newer netip.AddrPort -// write/read methods. -type PacketConn interface { - WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error) - ReadFromUDPAddrPort([]byte) (int, netip.AddrPort, error) - io.Closer - LocalAddr() net.Addr - SetDeadline(time.Time) error - SetReadDeadline(time.Time) error - SetWriteDeadline(time.Time) error -} - -func MakePacketListenerWithNetIP(ln PacketListener) PacketListenerWithNetIP { - return packetListenerAdapter{ln} -} - -type packetListenerAdapter struct { - PacketListener -} - -func (a packetListenerAdapter) ListenPacket(ctx context.Context, network, address string) (PacketConn, error) { - pc, err := a.PacketListener.ListenPacket(ctx, network, address) - if err != nil { - return nil, err - } - return pc.(PacketConn), nil -} - -// ConnPacketConn is the interface that's a superset of net.Conn and net.PacketConn. -type ConnPacketConn interface { - net.Conn - net.PacketConn -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package nettype defines an interface that doesn't exist in the Go net package. +package nettype + +import ( + "context" + "io" + "net" + "net/netip" + "time" +) + +// PacketListener defines the ListenPacket method as implemented +// by net.ListenConfig, net.ListenPacket, and tstest/natlab. +type PacketListener interface { + ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) +} + +type PacketListenerWithNetIP interface { + ListenPacket(ctx context.Context, network, address string) (PacketConn, error) +} + +// Std implements PacketListener using the Go net package's ListenPacket func. +type Std struct{} + +func (Std) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { + var conf net.ListenConfig + return conf.ListenPacket(ctx, network, address) +} + +// PacketConn is like a net.PacketConn but uses the newer netip.AddrPort +// write/read methods. +type PacketConn interface { + WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error) + ReadFromUDPAddrPort([]byte) (int, netip.AddrPort, error) + io.Closer + LocalAddr() net.Addr + SetDeadline(time.Time) error + SetReadDeadline(time.Time) error + SetWriteDeadline(time.Time) error +} + +func MakePacketListenerWithNetIP(ln PacketListener) PacketListenerWithNetIP { + return packetListenerAdapter{ln} +} + +type packetListenerAdapter struct { + PacketListener +} + +func (a packetListenerAdapter) ListenPacket(ctx context.Context, network, address string) (PacketConn, error) { + pc, err := a.PacketListener.ListenPacket(ctx, network, address) + if err != nil { + return nil, err + } + return pc.(PacketConn), nil +} + +// ConnPacketConn is the interface that's a superset of net.Conn and net.PacketConn. +type ConnPacketConn interface { + net.Conn + net.PacketConn +} diff --git a/types/preftype/netfiltermode.go b/types/preftype/netfiltermode.go index 5756e5096..273e17344 100644 --- a/types/preftype/netfiltermode.go +++ b/types/preftype/netfiltermode.go @@ -1,46 +1,46 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package preftype is a leaf package containing types for various -// preferences. -package preftype - -import "fmt" - -// NetfilterMode is the firewall management mode to use when -// programming the Linux network stack. -type NetfilterMode int - -// These numbers are persisted to disk in JSON files and thus can't be -// renumbered or repurposed. -const ( - NetfilterOff NetfilterMode = 0 // remove all tailscale netfilter state - NetfilterNoDivert NetfilterMode = 1 // manage tailscale chains, but don't call them - NetfilterOn NetfilterMode = 2 // manage tailscale chains and call them from main chains -) - -func ParseNetfilterMode(s string) (NetfilterMode, error) { - switch s { - case "off": - return NetfilterOff, nil - case "nodivert": - return NetfilterNoDivert, nil - case "on": - return NetfilterOn, nil - default: - return NetfilterOff, fmt.Errorf("unknown netfilter mode %q", s) - } -} - -func (m NetfilterMode) String() string { - switch m { - case NetfilterOff: - return "off" - case NetfilterNoDivert: - return "nodivert" - case NetfilterOn: - return "on" - default: - return "???" - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package preftype is a leaf package containing types for various +// preferences. +package preftype + +import "fmt" + +// NetfilterMode is the firewall management mode to use when +// programming the Linux network stack. +type NetfilterMode int + +// These numbers are persisted to disk in JSON files and thus can't be +// renumbered or repurposed. +const ( + NetfilterOff NetfilterMode = 0 // remove all tailscale netfilter state + NetfilterNoDivert NetfilterMode = 1 // manage tailscale chains, but don't call them + NetfilterOn NetfilterMode = 2 // manage tailscale chains and call them from main chains +) + +func ParseNetfilterMode(s string) (NetfilterMode, error) { + switch s { + case "off": + return NetfilterOff, nil + case "nodivert": + return NetfilterNoDivert, nil + case "on": + return NetfilterOn, nil + default: + return NetfilterOff, fmt.Errorf("unknown netfilter mode %q", s) + } +} + +func (m NetfilterMode) String() string { + switch m { + case NetfilterOff: + return "off" + case NetfilterNoDivert: + return "nodivert" + case NetfilterOn: + return "on" + default: + return "???" + } +} diff --git a/types/ptr/ptr.go b/types/ptr/ptr.go index beb955bf0..beb17bee8 100644 --- a/types/ptr/ptr.go +++ b/types/ptr/ptr.go @@ -1,10 +1,10 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package ptr contains the ptr.To function. -package ptr - -// To returns a pointer to a shallow copy of v. -func To[T any](v T) *T { - return &v -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package ptr contains the ptr.To function. +package ptr + +// To returns a pointer to a shallow copy of v. +func To[T any](v T) *T { + return &v +} diff --git a/types/structs/structs.go b/types/structs/structs.go index bac6b2991..47c359f0c 100644 --- a/types/structs/structs.go +++ b/types/structs/structs.go @@ -1,15 +1,15 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package structs contains the Incomparable type. -package structs - -// Incomparable is a zero-width incomparable type. If added as the -// first field in a struct, it marks that struct as not comparable -// (can't do == or be a map key) and usually doesn't add any width to -// the struct (unless the struct has only small fields). -// -// Be making a struct incomparable, you can prevent misuse (prevent -// people from using ==), but also you can shrink generated binaries, -// as the compiler can omit equality funcs from the binary. -type Incomparable [0]func() +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package structs contains the Incomparable type. +package structs + +// Incomparable is a zero-width incomparable type. If added as the +// first field in a struct, it marks that struct as not comparable +// (can't do == or be a map key) and usually doesn't add any width to +// the struct (unless the struct has only small fields). +// +// Be making a struct incomparable, you can prevent misuse (prevent +// people from using ==), but also you can shrink generated binaries, +// as the compiler can omit equality funcs from the binary. +type Incomparable [0]func() diff --git a/types/tkatype/tkatype.go b/types/tkatype/tkatype.go index aca6f1443..6ad51f6a9 100644 --- a/types/tkatype/tkatype.go +++ b/types/tkatype/tkatype.go @@ -1,40 +1,40 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package tkatype defines types for working with the tka package. -// -// Do not add extra dependencies to this package unless they are tiny, -// because this package encodes wire types that should be lightweight to use. -package tkatype - -// KeyID references a verification key stored in the key authority. A keyID -// uniquely identifies a key. KeyIDs are all 32 bytes. -// -// For 25519 keys: We just use the 32-byte public key. -// -// Even though this is a 32-byte value, we use a byte slice because -// CBOR-encoded byte slices have a different prefix to CBOR-encoded arrays. -// Encoding as a byte slice allows us to change the size in the future if we -// ever need to. -type KeyID []byte - -// MarshaledSignature represents a marshaled tka.NodeKeySignature. -type MarshaledSignature []byte - -// MarshaledAUM represents a marshaled tka.AUM. -type MarshaledAUM []byte - -// AUMSigHash represents the BLAKE2s digest of an Authority Update -// Message (AUM), sans any signatures. -type AUMSigHash [32]byte - -// NKSSigHash represents the BLAKE2s digest of a Node-Key Signature (NKS), -// sans the Signature field if present. -type NKSSigHash [32]byte - -// Signature describes a signature over an AUM, which can be verified -// using the key referenced by KeyID. -type Signature struct { - KeyID KeyID `cbor:"1,keyasint"` - Signature []byte `cbor:"2,keyasint"` -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package tkatype defines types for working with the tka package. +// +// Do not add extra dependencies to this package unless they are tiny, +// because this package encodes wire types that should be lightweight to use. +package tkatype + +// KeyID references a verification key stored in the key authority. A keyID +// uniquely identifies a key. KeyIDs are all 32 bytes. +// +// For 25519 keys: We just use the 32-byte public key. +// +// Even though this is a 32-byte value, we use a byte slice because +// CBOR-encoded byte slices have a different prefix to CBOR-encoded arrays. +// Encoding as a byte slice allows us to change the size in the future if we +// ever need to. +type KeyID []byte + +// MarshaledSignature represents a marshaled tka.NodeKeySignature. +type MarshaledSignature []byte + +// MarshaledAUM represents a marshaled tka.AUM. +type MarshaledAUM []byte + +// AUMSigHash represents the BLAKE2s digest of an Authority Update +// Message (AUM), sans any signatures. +type AUMSigHash [32]byte + +// NKSSigHash represents the BLAKE2s digest of a Node-Key Signature (NKS), +// sans the Signature field if present. +type NKSSigHash [32]byte + +// Signature describes a signature over an AUM, which can be verified +// using the key referenced by KeyID. +type Signature struct { + KeyID KeyID `cbor:"1,keyasint"` + Signature []byte `cbor:"2,keyasint"` +} diff --git a/types/tkatype/tkatype_test.go b/types/tkatype/tkatype_test.go index bff908072..c81891b9c 100644 --- a/types/tkatype/tkatype_test.go +++ b/types/tkatype/tkatype_test.go @@ -1,43 +1,43 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tkatype - -import ( - "encoding/json" - "testing" - - "golang.org/x/crypto/blake2s" -) - -func TestSigHashSize(t *testing.T) { - var sigHash AUMSigHash - if len(sigHash) != blake2s.Size { - t.Errorf("AUMSigHash is wrong size: got %d, want %d", len(sigHash), blake2s.Size) - } - - var nksHash NKSSigHash - if len(nksHash) != blake2s.Size { - t.Errorf("NKSSigHash is wrong size: got %d, want %d", len(nksHash), blake2s.Size) - } -} - -func TestMarshaledSignatureJSON(t *testing.T) { - sig := MarshaledSignature("abcdef") - j, err := json.Marshal(sig) - if err != nil { - t.Fatal(err) - } - const encoded = `"YWJjZGVm"` - if string(j) != encoded { - t.Errorf("got JSON %q; want %q", j, encoded) - } - - var back MarshaledSignature - if err := json.Unmarshal([]byte(encoded), &back); err != nil { - t.Fatal(err) - } - if string(back) != string(sig) { - t.Errorf("decoded JSON back to %q; want %q", back, sig) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tkatype + +import ( + "encoding/json" + "testing" + + "golang.org/x/crypto/blake2s" +) + +func TestSigHashSize(t *testing.T) { + var sigHash AUMSigHash + if len(sigHash) != blake2s.Size { + t.Errorf("AUMSigHash is wrong size: got %d, want %d", len(sigHash), blake2s.Size) + } + + var nksHash NKSSigHash + if len(nksHash) != blake2s.Size { + t.Errorf("NKSSigHash is wrong size: got %d, want %d", len(nksHash), blake2s.Size) + } +} + +func TestMarshaledSignatureJSON(t *testing.T) { + sig := MarshaledSignature("abcdef") + j, err := json.Marshal(sig) + if err != nil { + t.Fatal(err) + } + const encoded = `"YWJjZGVm"` + if string(j) != encoded { + t.Errorf("got JSON %q; want %q", j, encoded) + } + + var back MarshaledSignature + if err := json.Unmarshal([]byte(encoded), &back); err != nil { + t.Fatal(err) + } + if string(back) != string(sig) { + t.Errorf("decoded JSON back to %q; want %q", back, sig) + } +} diff --git a/util/cibuild/cibuild.go b/util/cibuild/cibuild.go index c3dee6154..c1e337f9a 100644 --- a/util/cibuild/cibuild.go +++ b/util/cibuild/cibuild.go @@ -1,14 +1,14 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package cibuild reports runtime CI information. -package cibuild - -import "os" - -// On reports whether the current binary is executing on a CI system. -func On() bool { - // CI env variable is set by GitHub. - // https://docs.github.com/en/actions/learn-github-actions/environment-variables#default-environment-variables - return os.Getenv("GITHUB_ACTIONS") != "" || os.Getenv("CI") == "true" -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package cibuild reports runtime CI information. +package cibuild + +import "os" + +// On reports whether the current binary is executing on a CI system. +func On() bool { + // CI env variable is set by GitHub. + // https://docs.github.com/en/actions/learn-github-actions/environment-variables#default-environment-variables + return os.Getenv("GITHUB_ACTIONS") != "" || os.Getenv("CI") == "true" +} diff --git a/util/cstruct/cstruct.go b/util/cstruct/cstruct.go index e32c90830..464dc5dc3 100644 --- a/util/cstruct/cstruct.go +++ b/util/cstruct/cstruct.go @@ -1,178 +1,178 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package cstruct provides a helper for decoding binary data that is in the -// form of a padded C structure. -package cstruct - -import ( - "errors" - "io" - - "github.com/josharian/native" -) - -// Size of a pointer-typed value, in bits -const pointerSize = 32 << (^uintptr(0) >> 63) - -// We assume that non-64-bit platforms are 32-bit; we don't expect Go to run on -// a 16- or 8-bit architecture any time soon. -const is64Bit = pointerSize == 64 - -// Decoder reads and decodes padded fields from a slice of bytes. All fields -// are decoded with native endianness. -// -// Methods of a Decoder do not return errors, but rather store any error within -// the Decoder. The first error can be obtained via the Err method; after the -// first error, methods will return the zero value for their type. -type Decoder struct { - b []byte - off int - err error - dbuf [8]byte // for decoding -} - -// NewDecoder creates a Decoder from a byte slice. -func NewDecoder(b []byte) *Decoder { - return &Decoder{b: b} -} - -var errUnsupportedSize = errors.New("unsupported size") - -func padBytes(offset, size int) int { - if offset == 0 || size == 1 { - return 0 - } - remainder := offset % size - return size - remainder -} - -func (d *Decoder) getField(b []byte) error { - size := len(b) - - // We only support fields that are multiples of 2 (or 1-sized) - if size != 1 && size&1 == 1 { - return errUnsupportedSize - } - - // Fields are aligned to their size - padBytes := padBytes(d.off, size) - if d.off+size+padBytes > len(d.b) { - return io.EOF - } - d.off += padBytes - - copy(b, d.b[d.off:d.off+size]) - d.off += size - return nil -} - -// Err returns the first error that was encountered by this Decoder. -func (d *Decoder) Err() error { - return d.err -} - -// Offset returns the current read offset for data in the buffer. -func (d *Decoder) Offset() int { - return d.off -} - -// Byte returns a single byte from the buffer. -func (d *Decoder) Byte() byte { - if d.err != nil { - return 0 - } - - if err := d.getField(d.dbuf[0:1]); err != nil { - d.err = err - return 0 - } - return d.dbuf[0] -} - -// Byte returns a number of bytes from the buffer based on the size of the -// input slice. No padding is applied. -// -// If an error is encountered or this Decoder has previously encountered an -// error, no changes are made to the provided buffer. -func (d *Decoder) Bytes(b []byte) { - if d.err != nil { - return - } - - // No padding for byte slices - size := len(b) - if d.off+size >= len(d.b) { - d.err = io.EOF - return - } - copy(b, d.b[d.off:d.off+size]) - d.off += size -} - -// Uint16 returns a uint16 decoded from the buffer. -func (d *Decoder) Uint16() uint16 { - if d.err != nil { - return 0 - } - - if err := d.getField(d.dbuf[0:2]); err != nil { - d.err = err - return 0 - } - return native.Endian.Uint16(d.dbuf[0:2]) -} - -// Uint32 returns a uint32 decoded from the buffer. -func (d *Decoder) Uint32() uint32 { - if d.err != nil { - return 0 - } - - if err := d.getField(d.dbuf[0:4]); err != nil { - d.err = err - return 0 - } - return native.Endian.Uint32(d.dbuf[0:4]) -} - -// Uint64 returns a uint64 decoded from the buffer. -func (d *Decoder) Uint64() uint64 { - if d.err != nil { - return 0 - } - - if err := d.getField(d.dbuf[0:8]); err != nil { - d.err = err - return 0 - } - return native.Endian.Uint64(d.dbuf[0:8]) -} - -// Uintptr returns a uintptr decoded from the buffer. -func (d *Decoder) Uintptr() uintptr { - if d.err != nil { - return 0 - } - - if is64Bit { - return uintptr(d.Uint64()) - } else { - return uintptr(d.Uint32()) - } -} - -// Int16 returns a int16 decoded from the buffer. -func (d *Decoder) Int16() int16 { - return int16(d.Uint16()) -} - -// Int32 returns a int32 decoded from the buffer. -func (d *Decoder) Int32() int32 { - return int32(d.Uint32()) -} - -// Int64 returns a int64 decoded from the buffer. -func (d *Decoder) Int64() int64 { - return int64(d.Uint64()) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package cstruct provides a helper for decoding binary data that is in the +// form of a padded C structure. +package cstruct + +import ( + "errors" + "io" + + "github.com/josharian/native" +) + +// Size of a pointer-typed value, in bits +const pointerSize = 32 << (^uintptr(0) >> 63) + +// We assume that non-64-bit platforms are 32-bit; we don't expect Go to run on +// a 16- or 8-bit architecture any time soon. +const is64Bit = pointerSize == 64 + +// Decoder reads and decodes padded fields from a slice of bytes. All fields +// are decoded with native endianness. +// +// Methods of a Decoder do not return errors, but rather store any error within +// the Decoder. The first error can be obtained via the Err method; after the +// first error, methods will return the zero value for their type. +type Decoder struct { + b []byte + off int + err error + dbuf [8]byte // for decoding +} + +// NewDecoder creates a Decoder from a byte slice. +func NewDecoder(b []byte) *Decoder { + return &Decoder{b: b} +} + +var errUnsupportedSize = errors.New("unsupported size") + +func padBytes(offset, size int) int { + if offset == 0 || size == 1 { + return 0 + } + remainder := offset % size + return size - remainder +} + +func (d *Decoder) getField(b []byte) error { + size := len(b) + + // We only support fields that are multiples of 2 (or 1-sized) + if size != 1 && size&1 == 1 { + return errUnsupportedSize + } + + // Fields are aligned to their size + padBytes := padBytes(d.off, size) + if d.off+size+padBytes > len(d.b) { + return io.EOF + } + d.off += padBytes + + copy(b, d.b[d.off:d.off+size]) + d.off += size + return nil +} + +// Err returns the first error that was encountered by this Decoder. +func (d *Decoder) Err() error { + return d.err +} + +// Offset returns the current read offset for data in the buffer. +func (d *Decoder) Offset() int { + return d.off +} + +// Byte returns a single byte from the buffer. +func (d *Decoder) Byte() byte { + if d.err != nil { + return 0 + } + + if err := d.getField(d.dbuf[0:1]); err != nil { + d.err = err + return 0 + } + return d.dbuf[0] +} + +// Byte returns a number of bytes from the buffer based on the size of the +// input slice. No padding is applied. +// +// If an error is encountered or this Decoder has previously encountered an +// error, no changes are made to the provided buffer. +func (d *Decoder) Bytes(b []byte) { + if d.err != nil { + return + } + + // No padding for byte slices + size := len(b) + if d.off+size >= len(d.b) { + d.err = io.EOF + return + } + copy(b, d.b[d.off:d.off+size]) + d.off += size +} + +// Uint16 returns a uint16 decoded from the buffer. +func (d *Decoder) Uint16() uint16 { + if d.err != nil { + return 0 + } + + if err := d.getField(d.dbuf[0:2]); err != nil { + d.err = err + return 0 + } + return native.Endian.Uint16(d.dbuf[0:2]) +} + +// Uint32 returns a uint32 decoded from the buffer. +func (d *Decoder) Uint32() uint32 { + if d.err != nil { + return 0 + } + + if err := d.getField(d.dbuf[0:4]); err != nil { + d.err = err + return 0 + } + return native.Endian.Uint32(d.dbuf[0:4]) +} + +// Uint64 returns a uint64 decoded from the buffer. +func (d *Decoder) Uint64() uint64 { + if d.err != nil { + return 0 + } + + if err := d.getField(d.dbuf[0:8]); err != nil { + d.err = err + return 0 + } + return native.Endian.Uint64(d.dbuf[0:8]) +} + +// Uintptr returns a uintptr decoded from the buffer. +func (d *Decoder) Uintptr() uintptr { + if d.err != nil { + return 0 + } + + if is64Bit { + return uintptr(d.Uint64()) + } else { + return uintptr(d.Uint32()) + } +} + +// Int16 returns a int16 decoded from the buffer. +func (d *Decoder) Int16() int16 { + return int16(d.Uint16()) +} + +// Int32 returns a int32 decoded from the buffer. +func (d *Decoder) Int32() int32 { + return int32(d.Uint32()) +} + +// Int64 returns a int64 decoded from the buffer. +func (d *Decoder) Int64() int64 { + return int64(d.Uint64()) +} diff --git a/util/cstruct/cstruct_example_test.go b/util/cstruct/cstruct_example_test.go index a36cbf9f0..17032267b 100644 --- a/util/cstruct/cstruct_example_test.go +++ b/util/cstruct/cstruct_example_test.go @@ -1,73 +1,73 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Only built on 64-bit platforms to avoid complexity - -//go:build amd64 || arm64 || mips64le || ppc64le || riscv64 - -package cstruct - -import "fmt" - -// This test provides a semi-realistic example of how you can -// use this package to decode a C structure. -func ExampleDecoder() { - // Our example C structure: - // struct mystruct { - // char *p; - // char c; - // /* implicit: char _pad[3]; */ - // int x; - // }; - // - // The Go structure definition: - type myStruct struct { - Ptr uintptr - Ch byte - Intval uint32 - } - - // Our "in-memory" version of the above structure - buf := []byte{ - 1, 2, 3, 4, 0, 0, 0, 0, // ptr - 5, // ch - 99, 99, 99, // padding - 78, 6, 0, 0, // x - } - d := NewDecoder(buf) - - // Decode the structure; if one of these function returns an error, - // then subsequent decoder functions will return the zero value. - var x myStruct - x.Ptr = d.Uintptr() - x.Ch = d.Byte() - x.Intval = d.Uint32() - - // Note that per the Go language spec: - // [...] when evaluating the operands of an expression, assignment, - // or return statement, all function calls, method calls, and - // (channel) communication operations are evaluated in lexical - // left-to-right order - // - // Since each field is assigned via a function call, one could use the - // following snippet to decode the struct. - // x := myStruct{ - // Ptr: d.Uintptr(), - // Ch: d.Byte(), - // Intval: d.Uint32(), - // } - // - // However, this means that reordering the fields in the initialization - // statement–normally a semantically identical operation–would change - // the way the structure is parsed. Thus we do it as above with - // explicit ordering. - - // After finishing with the decoder, check errors - if err := d.Err(); err != nil { - panic(err) - } - - // Print the decoder offset and structure - fmt.Printf("off=%d struct=%#v\n", d.Offset(), x) - // Output: off=16 struct=cstruct.myStruct{Ptr:0x4030201, Ch:0x5, Intval:0x64e} -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Only built on 64-bit platforms to avoid complexity + +//go:build amd64 || arm64 || mips64le || ppc64le || riscv64 + +package cstruct + +import "fmt" + +// This test provides a semi-realistic example of how you can +// use this package to decode a C structure. +func ExampleDecoder() { + // Our example C structure: + // struct mystruct { + // char *p; + // char c; + // /* implicit: char _pad[3]; */ + // int x; + // }; + // + // The Go structure definition: + type myStruct struct { + Ptr uintptr + Ch byte + Intval uint32 + } + + // Our "in-memory" version of the above structure + buf := []byte{ + 1, 2, 3, 4, 0, 0, 0, 0, // ptr + 5, // ch + 99, 99, 99, // padding + 78, 6, 0, 0, // x + } + d := NewDecoder(buf) + + // Decode the structure; if one of these function returns an error, + // then subsequent decoder functions will return the zero value. + var x myStruct + x.Ptr = d.Uintptr() + x.Ch = d.Byte() + x.Intval = d.Uint32() + + // Note that per the Go language spec: + // [...] when evaluating the operands of an expression, assignment, + // or return statement, all function calls, method calls, and + // (channel) communication operations are evaluated in lexical + // left-to-right order + // + // Since each field is assigned via a function call, one could use the + // following snippet to decode the struct. + // x := myStruct{ + // Ptr: d.Uintptr(), + // Ch: d.Byte(), + // Intval: d.Uint32(), + // } + // + // However, this means that reordering the fields in the initialization + // statement–normally a semantically identical operation–would change + // the way the structure is parsed. Thus we do it as above with + // explicit ordering. + + // After finishing with the decoder, check errors + if err := d.Err(); err != nil { + panic(err) + } + + // Print the decoder offset and structure + fmt.Printf("off=%d struct=%#v\n", d.Offset(), x) + // Output: off=16 struct=cstruct.myStruct{Ptr:0x4030201, Ch:0x5, Intval:0x64e} +} diff --git a/util/deephash/debug.go b/util/deephash/debug.go index ff417e583..50b3d5605 100644 --- a/util/deephash/debug.go +++ b/util/deephash/debug.go @@ -1,37 +1,37 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build deephash_debug - -package deephash - -import "fmt" - -func (h *hasher) HashBytes(b []byte) { - fmt.Printf("B(%q)+", b) - h.Block512.HashBytes(b) -} -func (h *hasher) HashString(s string) { - fmt.Printf("S(%q)+", s) - h.Block512.HashString(s) -} -func (h *hasher) HashUint8(n uint8) { - fmt.Printf("U8(%d)+", n) - h.Block512.HashUint8(n) -} -func (h *hasher) HashUint16(n uint16) { - fmt.Printf("U16(%d)+", n) - h.Block512.HashUint16(n) -} -func (h *hasher) HashUint32(n uint32) { - fmt.Printf("U32(%d)+", n) - h.Block512.HashUint32(n) -} -func (h *hasher) HashUint64(n uint64) { - fmt.Printf("U64(%d)+", n) - h.Block512.HashUint64(n) -} -func (h *hasher) Sum(b []byte) []byte { - fmt.Println("FIN") - return h.Block512.Sum(b) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build deephash_debug + +package deephash + +import "fmt" + +func (h *hasher) HashBytes(b []byte) { + fmt.Printf("B(%q)+", b) + h.Block512.HashBytes(b) +} +func (h *hasher) HashString(s string) { + fmt.Printf("S(%q)+", s) + h.Block512.HashString(s) +} +func (h *hasher) HashUint8(n uint8) { + fmt.Printf("U8(%d)+", n) + h.Block512.HashUint8(n) +} +func (h *hasher) HashUint16(n uint16) { + fmt.Printf("U16(%d)+", n) + h.Block512.HashUint16(n) +} +func (h *hasher) HashUint32(n uint32) { + fmt.Printf("U32(%d)+", n) + h.Block512.HashUint32(n) +} +func (h *hasher) HashUint64(n uint64) { + fmt.Printf("U64(%d)+", n) + h.Block512.HashUint64(n) +} +func (h *hasher) Sum(b []byte) []byte { + fmt.Println("FIN") + return h.Block512.Sum(b) +} diff --git a/util/deephash/pointer.go b/util/deephash/pointer.go index 71b11d7ff..aafae47a2 100644 --- a/util/deephash/pointer.go +++ b/util/deephash/pointer.go @@ -1,114 +1,114 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package deephash - -import ( - "net/netip" - "reflect" - "time" - "unsafe" -) - -// unsafePointer is an untyped pointer. -// It is the caller's responsibility to call operations on the correct type. -// -// This pointer only ever points to a small set of kinds or types: -// time.Time, netip.Addr, string, array, slice, struct, map, pointer, interface, -// or a pointer to memory that is directly hashable. -// -// Arrays are represented as pointers to the first element. -// Structs are represented as pointers to the first field. -// Slices are represented as pointers to a slice header. -// Pointers are represented as pointers to a pointer. -// -// We do not support direct operations on maps and interfaces, and instead -// rely on pointer.asValue to convert the pointer back to a reflect.Value. -// Conversion of an unsafe.Pointer to reflect.Value guarantees that the -// read-only flag in the reflect.Value is unpopulated, avoiding panics that may -// otherwise have occurred since the value was obtained from an unexported field. -type unsafePointer struct{ p unsafe.Pointer } - -func unsafePointerOf(v reflect.Value) unsafePointer { - return unsafePointer{v.UnsafePointer()} -} -func (p unsafePointer) isNil() bool { - return p.p == nil -} - -// pointerElem dereferences a pointer. -// p must point to a pointer. -func (p unsafePointer) pointerElem() unsafePointer { - return unsafePointer{*(*unsafe.Pointer)(p.p)} -} - -// sliceLen returns the slice length. -// p must point to a slice. -func (p unsafePointer) sliceLen() int { - return (*reflect.SliceHeader)(p.p).Len -} - -// sliceArray returns a pointer to the underlying slice array. -// p must point to a slice. -func (p unsafePointer) sliceArray() unsafePointer { - return unsafePointer{unsafe.Pointer((*reflect.SliceHeader)(p.p).Data)} -} - -// arrayIndex returns a pointer to an element in the array. -// p must point to an array. -func (p unsafePointer) arrayIndex(index int, size uintptr) unsafePointer { - return unsafePointer{unsafe.Add(p.p, uintptr(index)*size)} -} - -// structField returns a pointer to a field in a struct. -// p must pointer to a struct. -func (p unsafePointer) structField(index int, offset, size uintptr) unsafePointer { - return unsafePointer{unsafe.Add(p.p, offset)} -} - -// asString casts p as a *string. -func (p unsafePointer) asString() *string { - return (*string)(p.p) -} - -// asTime casts p as a *time.Time. -func (p unsafePointer) asTime() *time.Time { - return (*time.Time)(p.p) -} - -// asAddr casts p as a *netip.Addr. -func (p unsafePointer) asAddr() *netip.Addr { - return (*netip.Addr)(p.p) -} - -// asValue casts p as a reflect.Value containing a pointer to value of t. -func (p unsafePointer) asValue(typ reflect.Type) reflect.Value { - return reflect.NewAt(typ, p.p) -} - -// asMemory returns the memory pointer at by p for a specified size. -func (p unsafePointer) asMemory(size uintptr) []byte { - return unsafe.Slice((*byte)(p.p), size) -} - -// visitStack is a stack of pointers visited. -// Pointers are pushed onto the stack when visited, and popped when leaving. -// The integer value is the depth at which the pointer was visited. -// The length of this stack should be zero after every hashing operation. -type visitStack map[unsafe.Pointer]int - -func (v visitStack) seen(p unsafe.Pointer) (int, bool) { - idx, ok := v[p] - return idx, ok -} - -func (v *visitStack) push(p unsafe.Pointer) { - if *v == nil { - *v = make(map[unsafe.Pointer]int) - } - (*v)[p] = len(*v) -} - -func (v visitStack) pop(p unsafe.Pointer) { - delete(v, p) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package deephash + +import ( + "net/netip" + "reflect" + "time" + "unsafe" +) + +// unsafePointer is an untyped pointer. +// It is the caller's responsibility to call operations on the correct type. +// +// This pointer only ever points to a small set of kinds or types: +// time.Time, netip.Addr, string, array, slice, struct, map, pointer, interface, +// or a pointer to memory that is directly hashable. +// +// Arrays are represented as pointers to the first element. +// Structs are represented as pointers to the first field. +// Slices are represented as pointers to a slice header. +// Pointers are represented as pointers to a pointer. +// +// We do not support direct operations on maps and interfaces, and instead +// rely on pointer.asValue to convert the pointer back to a reflect.Value. +// Conversion of an unsafe.Pointer to reflect.Value guarantees that the +// read-only flag in the reflect.Value is unpopulated, avoiding panics that may +// otherwise have occurred since the value was obtained from an unexported field. +type unsafePointer struct{ p unsafe.Pointer } + +func unsafePointerOf(v reflect.Value) unsafePointer { + return unsafePointer{v.UnsafePointer()} +} +func (p unsafePointer) isNil() bool { + return p.p == nil +} + +// pointerElem dereferences a pointer. +// p must point to a pointer. +func (p unsafePointer) pointerElem() unsafePointer { + return unsafePointer{*(*unsafe.Pointer)(p.p)} +} + +// sliceLen returns the slice length. +// p must point to a slice. +func (p unsafePointer) sliceLen() int { + return (*reflect.SliceHeader)(p.p).Len +} + +// sliceArray returns a pointer to the underlying slice array. +// p must point to a slice. +func (p unsafePointer) sliceArray() unsafePointer { + return unsafePointer{unsafe.Pointer((*reflect.SliceHeader)(p.p).Data)} +} + +// arrayIndex returns a pointer to an element in the array. +// p must point to an array. +func (p unsafePointer) arrayIndex(index int, size uintptr) unsafePointer { + return unsafePointer{unsafe.Add(p.p, uintptr(index)*size)} +} + +// structField returns a pointer to a field in a struct. +// p must pointer to a struct. +func (p unsafePointer) structField(index int, offset, size uintptr) unsafePointer { + return unsafePointer{unsafe.Add(p.p, offset)} +} + +// asString casts p as a *string. +func (p unsafePointer) asString() *string { + return (*string)(p.p) +} + +// asTime casts p as a *time.Time. +func (p unsafePointer) asTime() *time.Time { + return (*time.Time)(p.p) +} + +// asAddr casts p as a *netip.Addr. +func (p unsafePointer) asAddr() *netip.Addr { + return (*netip.Addr)(p.p) +} + +// asValue casts p as a reflect.Value containing a pointer to value of t. +func (p unsafePointer) asValue(typ reflect.Type) reflect.Value { + return reflect.NewAt(typ, p.p) +} + +// asMemory returns the memory pointer at by p for a specified size. +func (p unsafePointer) asMemory(size uintptr) []byte { + return unsafe.Slice((*byte)(p.p), size) +} + +// visitStack is a stack of pointers visited. +// Pointers are pushed onto the stack when visited, and popped when leaving. +// The integer value is the depth at which the pointer was visited. +// The length of this stack should be zero after every hashing operation. +type visitStack map[unsafe.Pointer]int + +func (v visitStack) seen(p unsafe.Pointer) (int, bool) { + idx, ok := v[p] + return idx, ok +} + +func (v *visitStack) push(p unsafe.Pointer) { + if *v == nil { + *v = make(map[unsafe.Pointer]int) + } + (*v)[p] = len(*v) +} + +func (v visitStack) pop(p unsafe.Pointer) { + delete(v, p) +} diff --git a/util/deephash/pointer_norace.go b/util/deephash/pointer_norace.go index 499372000..f98a70f6a 100644 --- a/util/deephash/pointer_norace.go +++ b/util/deephash/pointer_norace.go @@ -1,13 +1,13 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !race - -package deephash - -import "reflect" - -type pointer = unsafePointer - -// pointerOf returns a pointer from v, which must be a reflect.Pointer. -func pointerOf(v reflect.Value) pointer { return unsafePointerOf(v) } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !race + +package deephash + +import "reflect" + +type pointer = unsafePointer + +// pointerOf returns a pointer from v, which must be a reflect.Pointer. +func pointerOf(v reflect.Value) pointer { return unsafePointerOf(v) } diff --git a/util/deephash/pointer_race.go b/util/deephash/pointer_race.go index 93a358b6d..c638c7d39 100644 --- a/util/deephash/pointer_race.go +++ b/util/deephash/pointer_race.go @@ -1,99 +1,99 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build race - -package deephash - -import ( - "fmt" - "net/netip" - "reflect" - "time" -) - -// pointer is a typed pointer that performs safety checks for every operation. -type pointer struct { - unsafePointer - t reflect.Type // type of pointed-at value; may be nil - n uintptr // size of valid memory after p -} - -// pointerOf returns a pointer from v, which must be a reflect.Pointer. -func pointerOf(v reflect.Value) pointer { - assert(v.Kind() == reflect.Pointer, "got %v, want pointer", v.Kind()) - te := v.Type().Elem() - return pointer{unsafePointerOf(v), te, te.Size()} -} - -func (p pointer) pointerElem() pointer { - assert(p.t.Kind() == reflect.Pointer, "got %v, want pointer", p.t.Kind()) - te := p.t.Elem() - return pointer{p.unsafePointer.pointerElem(), te, te.Size()} -} - -func (p pointer) sliceLen() int { - assert(p.t.Kind() == reflect.Slice, "got %v, want slice", p.t.Kind()) - return p.unsafePointer.sliceLen() -} - -func (p pointer) sliceArray() pointer { - assert(p.t.Kind() == reflect.Slice, "got %v, want slice", p.t.Kind()) - n := p.sliceLen() - assert(n >= 0, "got negative slice length %d", n) - ta := reflect.ArrayOf(n, p.t.Elem()) - return pointer{p.unsafePointer.sliceArray(), ta, ta.Size()} -} - -func (p pointer) arrayIndex(index int, size uintptr) pointer { - assert(p.t.Kind() == reflect.Array, "got %v, want array", p.t.Kind()) - assert(0 <= index && index < p.t.Len(), "got array of size %d, want to access element %d", p.t.Len(), index) - assert(p.t.Elem().Size() == size, "got element size of %d, want %d", p.t.Elem().Size(), size) - te := p.t.Elem() - return pointer{p.unsafePointer.arrayIndex(index, size), te, te.Size()} -} - -func (p pointer) structField(index int, offset, size uintptr) pointer { - assert(p.t.Kind() == reflect.Struct, "got %v, want struct", p.t.Kind()) - assert(p.n >= offset, "got size of %d, want excessive start offset of %d", p.n, offset) - assert(p.n >= offset+size, "got size of %d, want excessive end offset of %d", p.n, offset+size) - if index < 0 { - return pointer{p.unsafePointer.structField(index, offset, size), nil, size} - } - sf := p.t.Field(index) - t := sf.Type - assert(sf.Offset == offset, "got offset of %d, want offset %d", sf.Offset, offset) - assert(t.Size() == size, "got size of %d, want size %d", t.Size(), size) - return pointer{p.unsafePointer.structField(index, offset, size), t, t.Size()} -} - -func (p pointer) asString() *string { - assert(p.t.Kind() == reflect.String, "got %v, want string", p.t) - return p.unsafePointer.asString() -} - -func (p pointer) asTime() *time.Time { - assert(p.t == timeTimeType, "got %v, want %v", p.t, timeTimeType) - return p.unsafePointer.asTime() -} - -func (p pointer) asAddr() *netip.Addr { - assert(p.t == netipAddrType, "got %v, want %v", p.t, netipAddrType) - return p.unsafePointer.asAddr() -} - -func (p pointer) asValue(typ reflect.Type) reflect.Value { - assert(p.t == typ, "got %v, want %v", p.t, typ) - return p.unsafePointer.asValue(typ) -} - -func (p pointer) asMemory(size uintptr) []byte { - assert(p.n >= size, "got size of %d, want excessive size of %d", p.n, size) - return p.unsafePointer.asMemory(size) -} - -func assert(b bool, f string, a ...any) { - if !b { - panic(fmt.Sprintf(f, a...)) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build race + +package deephash + +import ( + "fmt" + "net/netip" + "reflect" + "time" +) + +// pointer is a typed pointer that performs safety checks for every operation. +type pointer struct { + unsafePointer + t reflect.Type // type of pointed-at value; may be nil + n uintptr // size of valid memory after p +} + +// pointerOf returns a pointer from v, which must be a reflect.Pointer. +func pointerOf(v reflect.Value) pointer { + assert(v.Kind() == reflect.Pointer, "got %v, want pointer", v.Kind()) + te := v.Type().Elem() + return pointer{unsafePointerOf(v), te, te.Size()} +} + +func (p pointer) pointerElem() pointer { + assert(p.t.Kind() == reflect.Pointer, "got %v, want pointer", p.t.Kind()) + te := p.t.Elem() + return pointer{p.unsafePointer.pointerElem(), te, te.Size()} +} + +func (p pointer) sliceLen() int { + assert(p.t.Kind() == reflect.Slice, "got %v, want slice", p.t.Kind()) + return p.unsafePointer.sliceLen() +} + +func (p pointer) sliceArray() pointer { + assert(p.t.Kind() == reflect.Slice, "got %v, want slice", p.t.Kind()) + n := p.sliceLen() + assert(n >= 0, "got negative slice length %d", n) + ta := reflect.ArrayOf(n, p.t.Elem()) + return pointer{p.unsafePointer.sliceArray(), ta, ta.Size()} +} + +func (p pointer) arrayIndex(index int, size uintptr) pointer { + assert(p.t.Kind() == reflect.Array, "got %v, want array", p.t.Kind()) + assert(0 <= index && index < p.t.Len(), "got array of size %d, want to access element %d", p.t.Len(), index) + assert(p.t.Elem().Size() == size, "got element size of %d, want %d", p.t.Elem().Size(), size) + te := p.t.Elem() + return pointer{p.unsafePointer.arrayIndex(index, size), te, te.Size()} +} + +func (p pointer) structField(index int, offset, size uintptr) pointer { + assert(p.t.Kind() == reflect.Struct, "got %v, want struct", p.t.Kind()) + assert(p.n >= offset, "got size of %d, want excessive start offset of %d", p.n, offset) + assert(p.n >= offset+size, "got size of %d, want excessive end offset of %d", p.n, offset+size) + if index < 0 { + return pointer{p.unsafePointer.structField(index, offset, size), nil, size} + } + sf := p.t.Field(index) + t := sf.Type + assert(sf.Offset == offset, "got offset of %d, want offset %d", sf.Offset, offset) + assert(t.Size() == size, "got size of %d, want size %d", t.Size(), size) + return pointer{p.unsafePointer.structField(index, offset, size), t, t.Size()} +} + +func (p pointer) asString() *string { + assert(p.t.Kind() == reflect.String, "got %v, want string", p.t) + return p.unsafePointer.asString() +} + +func (p pointer) asTime() *time.Time { + assert(p.t == timeTimeType, "got %v, want %v", p.t, timeTimeType) + return p.unsafePointer.asTime() +} + +func (p pointer) asAddr() *netip.Addr { + assert(p.t == netipAddrType, "got %v, want %v", p.t, netipAddrType) + return p.unsafePointer.asAddr() +} + +func (p pointer) asValue(typ reflect.Type) reflect.Value { + assert(p.t == typ, "got %v, want %v", p.t, typ) + return p.unsafePointer.asValue(typ) +} + +func (p pointer) asMemory(size uintptr) []byte { + assert(p.n >= size, "got size of %d, want excessive size of %d", p.n, size) + return p.unsafePointer.asMemory(size) +} + +func assert(b bool, f string, a ...any) { + if !b { + panic(fmt.Sprintf(f, a...)) + } +} diff --git a/util/deephash/testtype/testtype.go b/util/deephash/testtype/testtype.go index 2df38da87..3c90053d6 100644 --- a/util/deephash/testtype/testtype.go +++ b/util/deephash/testtype/testtype.go @@ -1,15 +1,15 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package testtype contains types for testing deephash. -package testtype - -import "time" - -type UnexportedAddressableTime struct { - t time.Time -} - -func NewUnexportedAddressableTime(t time.Time) *UnexportedAddressableTime { - return &UnexportedAddressableTime{t: t} -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package testtype contains types for testing deephash. +package testtype + +import "time" + +type UnexportedAddressableTime struct { + t time.Time +} + +func NewUnexportedAddressableTime(t time.Time) *UnexportedAddressableTime { + return &UnexportedAddressableTime{t: t} +} diff --git a/util/dirwalk/dirwalk.go b/util/dirwalk/dirwalk.go index a05ee3553..811766892 100644 --- a/util/dirwalk/dirwalk.go +++ b/util/dirwalk/dirwalk.go @@ -1,53 +1,53 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package dirwalk contains code to walk a directory. -package dirwalk - -import ( - "io" - "io/fs" - "os" - - "go4.org/mem" -) - -var osWalkShallow func(name mem.RO, fn WalkFunc) error - -// WalkFunc is the callback type used with WalkShallow. -// -// The name and de are only valid for the duration of func's call -// and should not be retained. -type WalkFunc func(name mem.RO, de fs.DirEntry) error - -// WalkShallow reads the entries in the named directory and calls fn for each. -// It does not recurse into subdirectories. -// -// If fn returns an error, iteration stops and WalkShallow returns that value. -// -// On Linux, WalkShallow does not allocate, so long as certain methods on the -// WalkFunc's DirEntry are not called which necessarily allocate. -func WalkShallow(dirName mem.RO, fn WalkFunc) error { - if f := osWalkShallow; f != nil { - return f(dirName, fn) - } - of, err := os.Open(dirName.StringCopy()) - if err != nil { - return err - } - defer of.Close() - for { - fis, err := of.ReadDir(100) - for _, de := range fis { - if err := fn(mem.S(de.Name()), de); err != nil { - return err - } - } - if err != nil { - if err == io.EOF { - return nil - } - return err - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package dirwalk contains code to walk a directory. +package dirwalk + +import ( + "io" + "io/fs" + "os" + + "go4.org/mem" +) + +var osWalkShallow func(name mem.RO, fn WalkFunc) error + +// WalkFunc is the callback type used with WalkShallow. +// +// The name and de are only valid for the duration of func's call +// and should not be retained. +type WalkFunc func(name mem.RO, de fs.DirEntry) error + +// WalkShallow reads the entries in the named directory and calls fn for each. +// It does not recurse into subdirectories. +// +// If fn returns an error, iteration stops and WalkShallow returns that value. +// +// On Linux, WalkShallow does not allocate, so long as certain methods on the +// WalkFunc's DirEntry are not called which necessarily allocate. +func WalkShallow(dirName mem.RO, fn WalkFunc) error { + if f := osWalkShallow; f != nil { + return f(dirName, fn) + } + of, err := os.Open(dirName.StringCopy()) + if err != nil { + return err + } + defer of.Close() + for { + fis, err := of.ReadDir(100) + for _, de := range fis { + if err := fn(mem.S(de.Name()), de); err != nil { + return err + } + } + if err != nil { + if err == io.EOF { + return nil + } + return err + } + } +} diff --git a/util/dirwalk/dirwalk_linux.go b/util/dirwalk/dirwalk_linux.go index 714783145..256467ebd 100644 --- a/util/dirwalk/dirwalk_linux.go +++ b/util/dirwalk/dirwalk_linux.go @@ -1,167 +1,167 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package dirwalk - -import ( - "fmt" - "io/fs" - "os" - "path/filepath" - "sync" - "syscall" - "unsafe" - - "go4.org/mem" - "golang.org/x/sys/unix" -) - -func init() { - osWalkShallow = linuxWalkShallow -} - -var dirEntPool = &sync.Pool{New: func() any { return new(linuxDirEnt) }} - -func linuxWalkShallow(dirName mem.RO, fn WalkFunc) error { - const blockSize = 8 << 10 - buf := make([]byte, blockSize) // stack-allocated; doesn't escape - - nameb := mem.Append(buf[:0], dirName) - nameb = append(nameb, 0) - - fd, err := sysOpen(nameb) - if err != nil { - return err - } - defer syscall.Close(fd) - - bufp := 0 // starting read position in buf - nbuf := 0 // end valid data in buf - - de := dirEntPool.Get().(*linuxDirEnt) - defer de.cleanAndPutInPool() - de.root = dirName - - for { - if bufp >= nbuf { - bufp = 0 - nbuf, err = readDirent(fd, buf) - if err != nil { - return err - } - if nbuf <= 0 { - return nil - } - } - consumed, name := parseDirEnt(&de.d, buf[bufp:nbuf]) - bufp += consumed - if len(name) == 0 || string(name) == "." || string(name) == ".." { - continue - } - de.name = mem.B(name) - if err := fn(de.name, de); err != nil { - return err - } - } -} - -type linuxDirEnt struct { - root mem.RO - d syscall.Dirent - name mem.RO -} - -func (de *linuxDirEnt) cleanAndPutInPool() { - de.root = mem.RO{} - de.name = mem.RO{} - dirEntPool.Put(de) -} - -func (de *linuxDirEnt) Name() string { return de.name.StringCopy() } -func (de *linuxDirEnt) Info() (fs.FileInfo, error) { - return os.Lstat(filepath.Join(de.root.StringCopy(), de.name.StringCopy())) -} -func (de *linuxDirEnt) IsDir() bool { - return de.d.Type == syscall.DT_DIR -} -func (de *linuxDirEnt) Type() fs.FileMode { - switch de.d.Type { - case syscall.DT_BLK: - return fs.ModeDevice // shrug - case syscall.DT_CHR: - return fs.ModeCharDevice - case syscall.DT_DIR: - return fs.ModeDir - case syscall.DT_FIFO: - return fs.ModeNamedPipe - case syscall.DT_LNK: - return fs.ModeSymlink - case syscall.DT_REG: - return 0 - case syscall.DT_SOCK: - return fs.ModeSocket - default: - return fs.ModeIrregular // shrug - } -} - -func direntNamlen(dirent *syscall.Dirent) int { - const fixedHdr = uint16(unsafe.Offsetof(syscall.Dirent{}.Name)) - limit := dirent.Reclen - fixedHdr - const dirNameLen = 256 // sizeof syscall.Dirent.Name - if limit > dirNameLen { - limit = dirNameLen - } - for i := uint16(0); i < limit; i++ { - if dirent.Name[i] == 0 { - return int(i) - } - } - panic("failed to find terminating 0 byte in dirent") -} - -func parseDirEnt(dirent *syscall.Dirent, buf []byte) (consumed int, name []byte) { - // golang.org/issue/37269 - copy(unsafe.Slice((*byte)(unsafe.Pointer(dirent)), unsafe.Sizeof(syscall.Dirent{})), buf) - if v := unsafe.Offsetof(dirent.Reclen) + unsafe.Sizeof(dirent.Reclen); uintptr(len(buf)) < v { - panic(fmt.Sprintf("buf size of %d smaller than dirent header size %d", len(buf), v)) - } - if len(buf) < int(dirent.Reclen) { - panic(fmt.Sprintf("buf size %d < record length %d", len(buf), dirent.Reclen)) - } - consumed = int(dirent.Reclen) - if dirent.Ino == 0 { // File absent in directory. - return - } - name = unsafe.Slice((*byte)(unsafe.Pointer(&dirent.Name[0])), direntNamlen(dirent)) - return -} - -func sysOpen(name []byte) (fd int, err error) { - if len(name) == 0 || name[len(name)-1] != 0 { - return 0, syscall.EINVAL - } - var dirfd int = unix.AT_FDCWD - for { - r0, _, e1 := syscall.Syscall(unix.SYS_OPENAT, uintptr(dirfd), - uintptr(unsafe.Pointer(&name[0])), 0) - if e1 == 0 { - return int(r0), nil - } - if e1 == syscall.EINTR { - // Since https://golang.org/doc/go1.14#runtime we - // need to loop on EINTR on more places. - continue - } - return 0, syscall.Errno(e1) - } -} - -func readDirent(fd int, buf []byte) (n int, err error) { - for { - nbuf, err := syscall.ReadDirent(fd, buf) - if err != syscall.EINTR { - return nbuf, err - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package dirwalk + +import ( + "fmt" + "io/fs" + "os" + "path/filepath" + "sync" + "syscall" + "unsafe" + + "go4.org/mem" + "golang.org/x/sys/unix" +) + +func init() { + osWalkShallow = linuxWalkShallow +} + +var dirEntPool = &sync.Pool{New: func() any { return new(linuxDirEnt) }} + +func linuxWalkShallow(dirName mem.RO, fn WalkFunc) error { + const blockSize = 8 << 10 + buf := make([]byte, blockSize) // stack-allocated; doesn't escape + + nameb := mem.Append(buf[:0], dirName) + nameb = append(nameb, 0) + + fd, err := sysOpen(nameb) + if err != nil { + return err + } + defer syscall.Close(fd) + + bufp := 0 // starting read position in buf + nbuf := 0 // end valid data in buf + + de := dirEntPool.Get().(*linuxDirEnt) + defer de.cleanAndPutInPool() + de.root = dirName + + for { + if bufp >= nbuf { + bufp = 0 + nbuf, err = readDirent(fd, buf) + if err != nil { + return err + } + if nbuf <= 0 { + return nil + } + } + consumed, name := parseDirEnt(&de.d, buf[bufp:nbuf]) + bufp += consumed + if len(name) == 0 || string(name) == "." || string(name) == ".." { + continue + } + de.name = mem.B(name) + if err := fn(de.name, de); err != nil { + return err + } + } +} + +type linuxDirEnt struct { + root mem.RO + d syscall.Dirent + name mem.RO +} + +func (de *linuxDirEnt) cleanAndPutInPool() { + de.root = mem.RO{} + de.name = mem.RO{} + dirEntPool.Put(de) +} + +func (de *linuxDirEnt) Name() string { return de.name.StringCopy() } +func (de *linuxDirEnt) Info() (fs.FileInfo, error) { + return os.Lstat(filepath.Join(de.root.StringCopy(), de.name.StringCopy())) +} +func (de *linuxDirEnt) IsDir() bool { + return de.d.Type == syscall.DT_DIR +} +func (de *linuxDirEnt) Type() fs.FileMode { + switch de.d.Type { + case syscall.DT_BLK: + return fs.ModeDevice // shrug + case syscall.DT_CHR: + return fs.ModeCharDevice + case syscall.DT_DIR: + return fs.ModeDir + case syscall.DT_FIFO: + return fs.ModeNamedPipe + case syscall.DT_LNK: + return fs.ModeSymlink + case syscall.DT_REG: + return 0 + case syscall.DT_SOCK: + return fs.ModeSocket + default: + return fs.ModeIrregular // shrug + } +} + +func direntNamlen(dirent *syscall.Dirent) int { + const fixedHdr = uint16(unsafe.Offsetof(syscall.Dirent{}.Name)) + limit := dirent.Reclen - fixedHdr + const dirNameLen = 256 // sizeof syscall.Dirent.Name + if limit > dirNameLen { + limit = dirNameLen + } + for i := uint16(0); i < limit; i++ { + if dirent.Name[i] == 0 { + return int(i) + } + } + panic("failed to find terminating 0 byte in dirent") +} + +func parseDirEnt(dirent *syscall.Dirent, buf []byte) (consumed int, name []byte) { + // golang.org/issue/37269 + copy(unsafe.Slice((*byte)(unsafe.Pointer(dirent)), unsafe.Sizeof(syscall.Dirent{})), buf) + if v := unsafe.Offsetof(dirent.Reclen) + unsafe.Sizeof(dirent.Reclen); uintptr(len(buf)) < v { + panic(fmt.Sprintf("buf size of %d smaller than dirent header size %d", len(buf), v)) + } + if len(buf) < int(dirent.Reclen) { + panic(fmt.Sprintf("buf size %d < record length %d", len(buf), dirent.Reclen)) + } + consumed = int(dirent.Reclen) + if dirent.Ino == 0 { // File absent in directory. + return + } + name = unsafe.Slice((*byte)(unsafe.Pointer(&dirent.Name[0])), direntNamlen(dirent)) + return +} + +func sysOpen(name []byte) (fd int, err error) { + if len(name) == 0 || name[len(name)-1] != 0 { + return 0, syscall.EINVAL + } + var dirfd int = unix.AT_FDCWD + for { + r0, _, e1 := syscall.Syscall(unix.SYS_OPENAT, uintptr(dirfd), + uintptr(unsafe.Pointer(&name[0])), 0) + if e1 == 0 { + return int(r0), nil + } + if e1 == syscall.EINTR { + // Since https://golang.org/doc/go1.14#runtime we + // need to loop on EINTR on more places. + continue + } + return 0, syscall.Errno(e1) + } +} + +func readDirent(fd int, buf []byte) (n int, err error) { + for { + nbuf, err := syscall.ReadDirent(fd, buf) + if err != syscall.EINTR { + return nbuf, err + } + } +} diff --git a/util/dirwalk/dirwalk_test.go b/util/dirwalk/dirwalk_test.go index e2e41f634..15ebc13dd 100644 --- a/util/dirwalk/dirwalk_test.go +++ b/util/dirwalk/dirwalk_test.go @@ -1,91 +1,91 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package dirwalk - -import ( - "fmt" - "os" - "path/filepath" - "reflect" - "runtime" - "sort" - "testing" - - "go4.org/mem" - "tailscale.com/tstest" -) - -func TestWalkShallowOSSpecific(t *testing.T) { - if osWalkShallow == nil { - t.Skip("no OS-specific implementation") - } - testWalkShallow(t, false) -} - -func TestWalkShallowPortable(t *testing.T) { - testWalkShallow(t, true) -} - -func testWalkShallow(t *testing.T, portable bool) { - if portable { - tstest.Replace(t, &osWalkShallow, nil) - } - d := t.TempDir() - - t.Run("basics", func(t *testing.T) { - if err := os.WriteFile(filepath.Join(d, "foo"), []byte("1"), 0600); err != nil { - t.Fatal(err) - } - if err := os.WriteFile(filepath.Join(d, "bar"), []byte("22"), 0400); err != nil { - t.Fatal(err) - } - if err := os.Mkdir(filepath.Join(d, "baz"), 0777); err != nil { - t.Fatal(err) - } - - var got []string - if err := WalkShallow(mem.S(d), func(name mem.RO, de os.DirEntry) error { - var size int64 - if fi, err := de.Info(); err != nil { - t.Errorf("Info stat error on %q: %v", de.Name(), err) - } else if !fi.IsDir() { - size = fi.Size() - } - got = append(got, fmt.Sprintf("%q %q dir=%v type=%d size=%v", name.StringCopy(), de.Name(), de.IsDir(), de.Type(), size)) - return nil - }); err != nil { - t.Fatal(err) - } - sort.Strings(got) - want := []string{ - `"bar" "bar" dir=false type=0 size=2`, - `"baz" "baz" dir=true type=2147483648 size=0`, - `"foo" "foo" dir=false type=0 size=1`, - } - if !reflect.DeepEqual(got, want) { - t.Errorf("mismatch:\n got %#q\nwant %#q", got, want) - } - }) - - t.Run("err_not_exist", func(t *testing.T) { - err := WalkShallow(mem.S(filepath.Join(d, "not_exist")), func(name mem.RO, de os.DirEntry) error { - return nil - }) - if !os.IsNotExist(err) { - t.Errorf("unexpected error: %v", err) - } - }) - - t.Run("allocs", func(t *testing.T) { - allocs := int(testing.AllocsPerRun(1000, func() { - if err := WalkShallow(mem.S(d), func(name mem.RO, de os.DirEntry) error { return nil }); err != nil { - t.Fatal(err) - } - })) - t.Logf("allocs = %v", allocs) - if !portable && runtime.GOOS == "linux" && allocs != 0 { - t.Errorf("unexpected allocs: got %v, want 0", allocs) - } - }) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package dirwalk + +import ( + "fmt" + "os" + "path/filepath" + "reflect" + "runtime" + "sort" + "testing" + + "go4.org/mem" + "tailscale.com/tstest" +) + +func TestWalkShallowOSSpecific(t *testing.T) { + if osWalkShallow == nil { + t.Skip("no OS-specific implementation") + } + testWalkShallow(t, false) +} + +func TestWalkShallowPortable(t *testing.T) { + testWalkShallow(t, true) +} + +func testWalkShallow(t *testing.T, portable bool) { + if portable { + tstest.Replace(t, &osWalkShallow, nil) + } + d := t.TempDir() + + t.Run("basics", func(t *testing.T) { + if err := os.WriteFile(filepath.Join(d, "foo"), []byte("1"), 0600); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(d, "bar"), []byte("22"), 0400); err != nil { + t.Fatal(err) + } + if err := os.Mkdir(filepath.Join(d, "baz"), 0777); err != nil { + t.Fatal(err) + } + + var got []string + if err := WalkShallow(mem.S(d), func(name mem.RO, de os.DirEntry) error { + var size int64 + if fi, err := de.Info(); err != nil { + t.Errorf("Info stat error on %q: %v", de.Name(), err) + } else if !fi.IsDir() { + size = fi.Size() + } + got = append(got, fmt.Sprintf("%q %q dir=%v type=%d size=%v", name.StringCopy(), de.Name(), de.IsDir(), de.Type(), size)) + return nil + }); err != nil { + t.Fatal(err) + } + sort.Strings(got) + want := []string{ + `"bar" "bar" dir=false type=0 size=2`, + `"baz" "baz" dir=true type=2147483648 size=0`, + `"foo" "foo" dir=false type=0 size=1`, + } + if !reflect.DeepEqual(got, want) { + t.Errorf("mismatch:\n got %#q\nwant %#q", got, want) + } + }) + + t.Run("err_not_exist", func(t *testing.T) { + err := WalkShallow(mem.S(filepath.Join(d, "not_exist")), func(name mem.RO, de os.DirEntry) error { + return nil + }) + if !os.IsNotExist(err) { + t.Errorf("unexpected error: %v", err) + } + }) + + t.Run("allocs", func(t *testing.T) { + allocs := int(testing.AllocsPerRun(1000, func() { + if err := WalkShallow(mem.S(d), func(name mem.RO, de os.DirEntry) error { return nil }); err != nil { + t.Fatal(err) + } + })) + t.Logf("allocs = %v", allocs) + if !portable && runtime.GOOS == "linux" && allocs != 0 { + t.Errorf("unexpected allocs: got %v, want 0", allocs) + } + }) +} diff --git a/util/goroutines/goroutines.go b/util/goroutines/goroutines.go index 24c61b37c..9758b0758 100644 --- a/util/goroutines/goroutines.go +++ b/util/goroutines/goroutines.go @@ -1,93 +1,93 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// The goroutines package contains utilities for getting active goroutines. -package goroutines - -import ( - "bytes" - "fmt" - "runtime" - "strconv" -) - -// ScrubbedGoroutineDump returns either the current goroutine's stack or all -// goroutines' stacks, but with the actual values of arguments scrubbed out, -// lest it contain some private key material. -func ScrubbedGoroutineDump(all bool) []byte { - var buf []byte - // Grab stacks multiple times into increasingly larger buffer sizes - // to minimize the risk that we blow past our iOS memory limit. - for size := 1 << 10; size <= 1<<20; size += 1 << 10 { - buf = make([]byte, size) - buf = buf[:runtime.Stack(buf, all)] - if len(buf) < size { - // It fit. - break - } - } - return scrubHex(buf) -} - -func scrubHex(buf []byte) []byte { - saw := map[string][]byte{} // "0x123" => "v1%3" (unique value 1 and its value mod 8) - - foreachHexAddress(buf, func(in []byte) { - if string(in) == "0x0" { - return - } - if v, ok := saw[string(in)]; ok { - for i := range in { - in[i] = '_' - } - copy(in, v) - return - } - inStr := string(in) - u64, err := strconv.ParseUint(string(in[2:]), 16, 64) - for i := range in { - in[i] = '_' - } - if err != nil { - in[0] = '?' - return - } - v := []byte(fmt.Sprintf("v%d%%%d", len(saw)+1, u64%8)) - saw[inStr] = v - copy(in, v) - }) - return buf -} - -var ohx = []byte("0x") - -// foreachHexAddress calls f with each subslice of b that matches -// regexp `0x[0-9a-f]*`. -func foreachHexAddress(b []byte, f func([]byte)) { - for len(b) > 0 { - i := bytes.Index(b, ohx) - if i == -1 { - return - } - b = b[i:] - hx := hexPrefix(b) - f(hx) - b = b[len(hx):] - } -} - -func hexPrefix(b []byte) []byte { - for i, c := range b { - if i < 2 { - continue - } - if !isHexByte(c) { - return b[:i] - } - } - return b -} - -func isHexByte(b byte) bool { - return '0' <= b && b <= '9' || 'a' <= b && b <= 'f' || 'A' <= b && b <= 'F' -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The goroutines package contains utilities for getting active goroutines. +package goroutines + +import ( + "bytes" + "fmt" + "runtime" + "strconv" +) + +// ScrubbedGoroutineDump returns either the current goroutine's stack or all +// goroutines' stacks, but with the actual values of arguments scrubbed out, +// lest it contain some private key material. +func ScrubbedGoroutineDump(all bool) []byte { + var buf []byte + // Grab stacks multiple times into increasingly larger buffer sizes + // to minimize the risk that we blow past our iOS memory limit. + for size := 1 << 10; size <= 1<<20; size += 1 << 10 { + buf = make([]byte, size) + buf = buf[:runtime.Stack(buf, all)] + if len(buf) < size { + // It fit. + break + } + } + return scrubHex(buf) +} + +func scrubHex(buf []byte) []byte { + saw := map[string][]byte{} // "0x123" => "v1%3" (unique value 1 and its value mod 8) + + foreachHexAddress(buf, func(in []byte) { + if string(in) == "0x0" { + return + } + if v, ok := saw[string(in)]; ok { + for i := range in { + in[i] = '_' + } + copy(in, v) + return + } + inStr := string(in) + u64, err := strconv.ParseUint(string(in[2:]), 16, 64) + for i := range in { + in[i] = '_' + } + if err != nil { + in[0] = '?' + return + } + v := []byte(fmt.Sprintf("v%d%%%d", len(saw)+1, u64%8)) + saw[inStr] = v + copy(in, v) + }) + return buf +} + +var ohx = []byte("0x") + +// foreachHexAddress calls f with each subslice of b that matches +// regexp `0x[0-9a-f]*`. +func foreachHexAddress(b []byte, f func([]byte)) { + for len(b) > 0 { + i := bytes.Index(b, ohx) + if i == -1 { + return + } + b = b[i:] + hx := hexPrefix(b) + f(hx) + b = b[len(hx):] + } +} + +func hexPrefix(b []byte) []byte { + for i, c := range b { + if i < 2 { + continue + } + if !isHexByte(c) { + return b[:i] + } + } + return b +} + +func isHexByte(b byte) bool { + return '0' <= b && b <= '9' || 'a' <= b && b <= 'f' || 'A' <= b && b <= 'F' +} diff --git a/util/goroutines/goroutines_test.go b/util/goroutines/goroutines_test.go index df6560fe5..ae17c399c 100644 --- a/util/goroutines/goroutines_test.go +++ b/util/goroutines/goroutines_test.go @@ -1,29 +1,29 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package goroutines - -import "testing" - -func TestScrubbedGoroutineDump(t *testing.T) { - t.Logf("Got:\n%s\n", ScrubbedGoroutineDump(true)) -} - -func TestScrubHex(t *testing.T) { - tests := []struct { - in, want string - }{ - {"foo", "foo"}, - {"", ""}, - {"0x", "?_"}, - {"0x001 and same 0x001", "v1%1_ and same v1%1_"}, - {"0x008 and same 0x008", "v1%0_ and same v1%0_"}, - {"0x001 and diff 0x002", "v1%1_ and diff v2%2_"}, - } - for _, tt := range tests { - got := scrubHex([]byte(tt.in)) - if string(got) != tt.want { - t.Errorf("for input:\n%s\n\ngot:\n%s\n\nwant:\n%s\n", tt.in, got, tt.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package goroutines + +import "testing" + +func TestScrubbedGoroutineDump(t *testing.T) { + t.Logf("Got:\n%s\n", ScrubbedGoroutineDump(true)) +} + +func TestScrubHex(t *testing.T) { + tests := []struct { + in, want string + }{ + {"foo", "foo"}, + {"", ""}, + {"0x", "?_"}, + {"0x001 and same 0x001", "v1%1_ and same v1%1_"}, + {"0x008 and same 0x008", "v1%0_ and same v1%0_"}, + {"0x001 and diff 0x002", "v1%1_ and diff v2%2_"}, + } + for _, tt := range tests { + got := scrubHex([]byte(tt.in)) + if string(got) != tt.want { + t.Errorf("for input:\n%s\n\ngot:\n%s\n\nwant:\n%s\n", tt.in, got, tt.want) + } + } +} diff --git a/util/groupmember/groupmember.go b/util/groupmember/groupmember.go index 38431a7ff..d60416816 100644 --- a/util/groupmember/groupmember.go +++ b/util/groupmember/groupmember.go @@ -1,29 +1,29 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package groupmember verifies group membership of the provided user on the -// local system. -package groupmember - -import ( - "os/user" - "slices" -) - -// IsMemberOfGroup reports whether the provided user is a member of -// the provided system group. -func IsMemberOfGroup(group, userName string) (bool, error) { - u, err := user.Lookup(userName) - if err != nil { - return false, err - } - g, err := user.LookupGroup(group) - if err != nil { - return false, err - } - ugids, err := u.GroupIds() - if err != nil { - return false, err - } - return slices.Contains(ugids, g.Gid), nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package groupmember verifies group membership of the provided user on the +// local system. +package groupmember + +import ( + "os/user" + "slices" +) + +// IsMemberOfGroup reports whether the provided user is a member of +// the provided system group. +func IsMemberOfGroup(group, userName string) (bool, error) { + u, err := user.Lookup(userName) + if err != nil { + return false, err + } + g, err := user.LookupGroup(group) + if err != nil { + return false, err + } + ugids, err := u.GroupIds() + if err != nil { + return false, err + } + return slices.Contains(ugids, g.Gid), nil +} diff --git a/util/hashx/block512.go b/util/hashx/block512.go index dd69ccd35..e637c0c03 100644 --- a/util/hashx/block512.go +++ b/util/hashx/block512.go @@ -1,197 +1,197 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package hashx provides a concrete implementation of [hash.Hash] -// that operates on a particular block size. -package hashx - -import ( - "encoding/binary" - "fmt" - "hash" - "unsafe" -) - -var _ hash.Hash = (*Block512)(nil) - -// Block512 wraps a [hash.Hash] for functions that operate on 512-bit block sizes. -// It has efficient methods for hashing fixed-width integers. -// -// A hashing algorithm that operates on 512-bit block sizes should be used. -// The hash still operates correctly even with misaligned block sizes, -// but operates less efficiently. -// -// Example algorithms with 512-bit block sizes include: -// - MD4 (https://golang.org/x/crypto/md4) -// - MD5 (https://golang.org/pkg/crypto/md5) -// - BLAKE2s (https://golang.org/x/crypto/blake2s) -// - BLAKE3 -// - RIPEMD (https://golang.org/x/crypto/ripemd160) -// - SHA-0 -// - SHA-1 (https://golang.org/pkg/crypto/sha1) -// - SHA-2 (https://golang.org/pkg/crypto/sha256) -// - Whirlpool -// -// See https://en.wikipedia.org/wiki/Comparison_of_cryptographic_hash_functions#Parameters -// for a list of hash functions and their block sizes. -// -// Block512 assumes that [hash.Hash.Write] never fails and -// never allows the provided buffer to escape. -type Block512 struct { - hash.Hash - - x [512 / 8]byte - nx int -} - -// New512 constructs a new Block512 that wraps h. -// -// It reports an error if the block sizes do not match. -// Misaligned block sizes perform poorly, but execute correctly. -// The error may be ignored if performance is not a concern. -func New512(h hash.Hash) (*Block512, error) { - b := &Block512{Hash: h} - if len(b.x)%h.BlockSize() != 0 { - return b, fmt.Errorf("hashx.Block512: inefficient use of hash.Hash with %d-bit block size", 8*h.BlockSize()) - } - return b, nil -} - -// Write hashes the contents of b. -func (h *Block512) Write(b []byte) (int, error) { - h.HashBytes(b) - return len(b), nil -} - -// Sum appends the current hash to b and returns the resulting slice. -// -// It flushes any partially completed blocks to the underlying [hash.Hash], -// which may cause future operations to be misaligned and less efficient -// until [Block512.Reset] is called. -func (h *Block512) Sum(b []byte) []byte { - if h.nx > 0 { - h.Hash.Write(h.x[:h.nx]) - h.nx = 0 - } - - // Unfortunately hash.Hash.Sum always causes the input to escape since - // escape analysis cannot prove anything past an interface method call. - // Assuming h already escapes, we call Sum with h.x first, - // and then copy the result to b. - sum := h.Hash.Sum(h.x[:0]) - return append(b, sum...) -} - -// Reset resets Block512 to its initial state. -// It recursively resets the underlying [hash.Hash]. -func (h *Block512) Reset() { - h.Hash.Reset() - h.nx = 0 -} - -// HashUint8 hashes n as a 1-byte integer. -func (h *Block512) HashUint8(n uint8) { - // NOTE: This method is carefully written to be inlineable. - if h.nx <= len(h.x)-1 { - h.x[h.nx] = n - h.nx += 1 - } else { - h.hashUint8Slow(n) // mark "noinline" to keep this within inline budget - } -} - -//go:noinline -func (h *Block512) hashUint8Slow(n uint8) { h.hashUint(uint64(n), 1) } - -// HashUint16 hashes n as a 2-byte little-endian integer. -func (h *Block512) HashUint16(n uint16) { - // NOTE: This method is carefully written to be inlineable. - if h.nx <= len(h.x)-2 { - binary.LittleEndian.PutUint16(h.x[h.nx:], n) - h.nx += 2 - } else { - h.hashUint16Slow(n) // mark "noinline" to keep this within inline budget - } -} - -//go:noinline -func (h *Block512) hashUint16Slow(n uint16) { h.hashUint(uint64(n), 2) } - -// HashUint32 hashes n as a 4-byte little-endian integer. -func (h *Block512) HashUint32(n uint32) { - // NOTE: This method is carefully written to be inlineable. - if h.nx <= len(h.x)-4 { - binary.LittleEndian.PutUint32(h.x[h.nx:], n) - h.nx += 4 - } else { - h.hashUint32Slow(n) // mark "noinline" to keep this within inline budget - } -} - -//go:noinline -func (h *Block512) hashUint32Slow(n uint32) { h.hashUint(uint64(n), 4) } - -// HashUint64 hashes n as a 8-byte little-endian integer. -func (h *Block512) HashUint64(n uint64) { - // NOTE: This method is carefully written to be inlineable. - if h.nx <= len(h.x)-8 { - binary.LittleEndian.PutUint64(h.x[h.nx:], n) - h.nx += 8 - } else { - h.hashUint64Slow(n) // mark "noinline" to keep this within inline budget - } -} - -//go:noinline -func (h *Block512) hashUint64Slow(n uint64) { h.hashUint(uint64(n), 8) } - -func (h *Block512) hashUint(n uint64, i int) { - for ; i > 0; i-- { - if h.nx == len(h.x) { - h.Hash.Write(h.x[:]) - h.nx = 0 - } - h.x[h.nx] = byte(n) - h.nx += 1 - n >>= 8 - } -} - -// HashBytes hashes the contents of b. -// It does not explicitly hash the length separately. -func (h *Block512) HashBytes(b []byte) { - // Nearly identical to sha256.digest.Write. - if h.nx > 0 { - n := copy(h.x[h.nx:], b) - h.nx += n - if h.nx == len(h.x) { - h.Hash.Write(h.x[:]) - h.nx = 0 - } - b = b[n:] - } - if len(b) >= len(h.x) { - n := len(b) &^ (len(h.x) - 1) // n is a multiple of len(h.x) - h.Hash.Write(b[:n]) - b = b[n:] - } - if len(b) > 0 { - h.nx = copy(h.x[:], b) - } -} - -// HashString hashes the contents of s. -// It does not explicitly hash the length separately. -func (h *Block512) HashString(s string) { - // TODO: Avoid unsafe when standard hashers implement io.StringWriter. - // See https://go.dev/issue/38776. - type stringHeader struct { - p unsafe.Pointer - n int - } - p := (*stringHeader)(unsafe.Pointer(&s)) - b := unsafe.Slice((*byte)(p.p), p.n) - h.HashBytes(b) -} - -// TODO: Add Hash.MarshalBinary and Hash.UnmarshalBinary? +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package hashx provides a concrete implementation of [hash.Hash] +// that operates on a particular block size. +package hashx + +import ( + "encoding/binary" + "fmt" + "hash" + "unsafe" +) + +var _ hash.Hash = (*Block512)(nil) + +// Block512 wraps a [hash.Hash] for functions that operate on 512-bit block sizes. +// It has efficient methods for hashing fixed-width integers. +// +// A hashing algorithm that operates on 512-bit block sizes should be used. +// The hash still operates correctly even with misaligned block sizes, +// but operates less efficiently. +// +// Example algorithms with 512-bit block sizes include: +// - MD4 (https://golang.org/x/crypto/md4) +// - MD5 (https://golang.org/pkg/crypto/md5) +// - BLAKE2s (https://golang.org/x/crypto/blake2s) +// - BLAKE3 +// - RIPEMD (https://golang.org/x/crypto/ripemd160) +// - SHA-0 +// - SHA-1 (https://golang.org/pkg/crypto/sha1) +// - SHA-2 (https://golang.org/pkg/crypto/sha256) +// - Whirlpool +// +// See https://en.wikipedia.org/wiki/Comparison_of_cryptographic_hash_functions#Parameters +// for a list of hash functions and their block sizes. +// +// Block512 assumes that [hash.Hash.Write] never fails and +// never allows the provided buffer to escape. +type Block512 struct { + hash.Hash + + x [512 / 8]byte + nx int +} + +// New512 constructs a new Block512 that wraps h. +// +// It reports an error if the block sizes do not match. +// Misaligned block sizes perform poorly, but execute correctly. +// The error may be ignored if performance is not a concern. +func New512(h hash.Hash) (*Block512, error) { + b := &Block512{Hash: h} + if len(b.x)%h.BlockSize() != 0 { + return b, fmt.Errorf("hashx.Block512: inefficient use of hash.Hash with %d-bit block size", 8*h.BlockSize()) + } + return b, nil +} + +// Write hashes the contents of b. +func (h *Block512) Write(b []byte) (int, error) { + h.HashBytes(b) + return len(b), nil +} + +// Sum appends the current hash to b and returns the resulting slice. +// +// It flushes any partially completed blocks to the underlying [hash.Hash], +// which may cause future operations to be misaligned and less efficient +// until [Block512.Reset] is called. +func (h *Block512) Sum(b []byte) []byte { + if h.nx > 0 { + h.Hash.Write(h.x[:h.nx]) + h.nx = 0 + } + + // Unfortunately hash.Hash.Sum always causes the input to escape since + // escape analysis cannot prove anything past an interface method call. + // Assuming h already escapes, we call Sum with h.x first, + // and then copy the result to b. + sum := h.Hash.Sum(h.x[:0]) + return append(b, sum...) +} + +// Reset resets Block512 to its initial state. +// It recursively resets the underlying [hash.Hash]. +func (h *Block512) Reset() { + h.Hash.Reset() + h.nx = 0 +} + +// HashUint8 hashes n as a 1-byte integer. +func (h *Block512) HashUint8(n uint8) { + // NOTE: This method is carefully written to be inlineable. + if h.nx <= len(h.x)-1 { + h.x[h.nx] = n + h.nx += 1 + } else { + h.hashUint8Slow(n) // mark "noinline" to keep this within inline budget + } +} + +//go:noinline +func (h *Block512) hashUint8Slow(n uint8) { h.hashUint(uint64(n), 1) } + +// HashUint16 hashes n as a 2-byte little-endian integer. +func (h *Block512) HashUint16(n uint16) { + // NOTE: This method is carefully written to be inlineable. + if h.nx <= len(h.x)-2 { + binary.LittleEndian.PutUint16(h.x[h.nx:], n) + h.nx += 2 + } else { + h.hashUint16Slow(n) // mark "noinline" to keep this within inline budget + } +} + +//go:noinline +func (h *Block512) hashUint16Slow(n uint16) { h.hashUint(uint64(n), 2) } + +// HashUint32 hashes n as a 4-byte little-endian integer. +func (h *Block512) HashUint32(n uint32) { + // NOTE: This method is carefully written to be inlineable. + if h.nx <= len(h.x)-4 { + binary.LittleEndian.PutUint32(h.x[h.nx:], n) + h.nx += 4 + } else { + h.hashUint32Slow(n) // mark "noinline" to keep this within inline budget + } +} + +//go:noinline +func (h *Block512) hashUint32Slow(n uint32) { h.hashUint(uint64(n), 4) } + +// HashUint64 hashes n as a 8-byte little-endian integer. +func (h *Block512) HashUint64(n uint64) { + // NOTE: This method is carefully written to be inlineable. + if h.nx <= len(h.x)-8 { + binary.LittleEndian.PutUint64(h.x[h.nx:], n) + h.nx += 8 + } else { + h.hashUint64Slow(n) // mark "noinline" to keep this within inline budget + } +} + +//go:noinline +func (h *Block512) hashUint64Slow(n uint64) { h.hashUint(uint64(n), 8) } + +func (h *Block512) hashUint(n uint64, i int) { + for ; i > 0; i-- { + if h.nx == len(h.x) { + h.Hash.Write(h.x[:]) + h.nx = 0 + } + h.x[h.nx] = byte(n) + h.nx += 1 + n >>= 8 + } +} + +// HashBytes hashes the contents of b. +// It does not explicitly hash the length separately. +func (h *Block512) HashBytes(b []byte) { + // Nearly identical to sha256.digest.Write. + if h.nx > 0 { + n := copy(h.x[h.nx:], b) + h.nx += n + if h.nx == len(h.x) { + h.Hash.Write(h.x[:]) + h.nx = 0 + } + b = b[n:] + } + if len(b) >= len(h.x) { + n := len(b) &^ (len(h.x) - 1) // n is a multiple of len(h.x) + h.Hash.Write(b[:n]) + b = b[n:] + } + if len(b) > 0 { + h.nx = copy(h.x[:], b) + } +} + +// HashString hashes the contents of s. +// It does not explicitly hash the length separately. +func (h *Block512) HashString(s string) { + // TODO: Avoid unsafe when standard hashers implement io.StringWriter. + // See https://go.dev/issue/38776. + type stringHeader struct { + p unsafe.Pointer + n int + } + p := (*stringHeader)(unsafe.Pointer(&s)) + b := unsafe.Slice((*byte)(p.p), p.n) + h.HashBytes(b) +} + +// TODO: Add Hash.MarshalBinary and Hash.UnmarshalBinary? diff --git a/util/httphdr/httphdr.go b/util/httphdr/httphdr.go index b78b165c6..852e28b8f 100644 --- a/util/httphdr/httphdr.go +++ b/util/httphdr/httphdr.go @@ -1,197 +1,197 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package httphdr implements functionality for parsing and formatting -// standard HTTP headers. -package httphdr - -import ( - "bytes" - "strconv" - "strings" -) - -// Range is a range of bytes within some content. -type Range struct { - // Start is the starting offset. - // It is zero if Length is negative; it must not be negative. - Start int64 - // Length is the length of the content. - // It is zero if the length extends to the end of the content. - // It is negative if the length is relative to the end (e.g., last 5 bytes). - Length int64 -} - -// ows is optional whitespace. -const ows = " \t" // per RFC 7230, section 3.2.3 - -// ParseRange parses a "Range" header per RFC 7233, section 3. -// It only handles "Range" headers where the units is "bytes". -// The "Range" header is usually only specified in GET requests. -func ParseRange(hdr string) (ranges []Range, ok bool) { - // Grammar per RFC 7233, appendix D: - // Range = byte-ranges-specifier | other-ranges-specifier - // byte-ranges-specifier = bytes-unit "=" byte-range-set - // bytes-unit = "bytes" - // byte-range-set = - // *("," OWS) - // (byte-range-spec | suffix-byte-range-spec) - // *(OWS "," [OWS ( byte-range-spec | suffix-byte-range-spec )]) - // byte-range-spec = first-byte-pos "-" [last-byte-pos] - // suffix-byte-range-spec = "-" suffix-length - // We do not support other-ranges-specifier. - // All other identifiers are 1*DIGIT. - hdr = strings.Trim(hdr, ows) // per RFC 7230, section 3.2 - units, elems, hasUnits := strings.Cut(hdr, "=") - elems = strings.TrimLeft(elems, ","+ows) - for _, elem := range strings.Split(elems, ",") { - elem = strings.Trim(elem, ows) // per RFC 7230, section 7 - switch { - case strings.HasPrefix(elem, "-"): // i.e., "-" suffix-length - n, ok := parseNumber(strings.TrimPrefix(elem, "-")) - if !ok { - return ranges, false - } - ranges = append(ranges, Range{0, -n}) - case strings.HasSuffix(elem, "-"): // i.e., first-byte-pos "-" - n, ok := parseNumber(strings.TrimSuffix(elem, "-")) - if !ok { - return ranges, false - } - ranges = append(ranges, Range{n, 0}) - default: // i.e., first-byte-pos "-" last-byte-pos - prefix, suffix, hasDash := strings.Cut(elem, "-") - n, ok2 := parseNumber(prefix) - m, ok3 := parseNumber(suffix) - if !hasDash || !ok2 || !ok3 || m < n { - return ranges, false - } - ranges = append(ranges, Range{n, m - n + 1}) - } - } - return ranges, units == "bytes" && hasUnits && len(ranges) > 0 // must see at least one element per RFC 7233, section 2.1 -} - -// FormatRange formats a "Range" header per RFC 7233, section 3. -// It only handles "Range" headers where the units is "bytes". -// The "Range" header is usually only specified in GET requests. -func FormatRange(ranges []Range) (hdr string, ok bool) { - b := []byte("bytes=") - for _, r := range ranges { - switch { - case r.Length > 0: // i.e., first-byte-pos "-" last-byte-pos - if r.Start < 0 { - return string(b), false - } - b = strconv.AppendUint(b, uint64(r.Start), 10) - b = append(b, '-') - b = strconv.AppendUint(b, uint64(r.Start+r.Length-1), 10) - b = append(b, ',') - case r.Length == 0: // i.e., first-byte-pos "-" - if r.Start < 0 { - return string(b), false - } - b = strconv.AppendUint(b, uint64(r.Start), 10) - b = append(b, '-') - b = append(b, ',') - case r.Length < 0: // i.e., "-" suffix-length - if r.Start != 0 { - return string(b), false - } - b = append(b, '-') - b = strconv.AppendUint(b, uint64(-r.Length), 10) - b = append(b, ',') - default: - return string(b), false - } - } - return string(bytes.TrimRight(b, ",")), len(ranges) > 0 -} - -// ParseContentRange parses a "Content-Range" header per RFC 7233, section 4.2. -// It only handles "Content-Range" headers where the units is "bytes". -// The "Content-Range" header is usually only specified in HTTP responses. -// -// If only the completeLength is specified, then start and length are both zero. -// -// Otherwise, the parses the start and length and the optional completeLength, -// which is -1 if unspecified. The start is non-negative and the length is positive. -func ParseContentRange(hdr string) (start, length, completeLength int64, ok bool) { - // Grammar per RFC 7233, appendix D: - // Content-Range = byte-content-range | other-content-range - // byte-content-range = bytes-unit SP (byte-range-resp | unsatisfied-range) - // bytes-unit = "bytes" - // byte-range-resp = byte-range "/" (complete-length | "*") - // unsatisfied-range = "*/" complete-length - // byte-range = first-byte-pos "-" last-byte-pos - // We do not support other-content-range. - // All other identifiers are 1*DIGIT. - hdr = strings.Trim(hdr, ows) // per RFC 7230, section 3.2 - suffix, hasUnits := strings.CutPrefix(hdr, "bytes ") - suffix, unsatisfied := strings.CutPrefix(suffix, "*/") - if unsatisfied { // i.e., unsatisfied-range - n, ok := parseNumber(suffix) - if !ok { - return start, length, completeLength, false - } - completeLength = n - } else { // i.e., byte-range "/" (complete-length | "*") - prefix, suffix, hasDash := strings.Cut(suffix, "-") - middle, suffix, hasSlash := strings.Cut(suffix, "/") - n, ok0 := parseNumber(prefix) - m, ok1 := parseNumber(middle) - o, ok2 := parseNumber(suffix) - if suffix == "*" { - o, ok2 = -1, true - } - if !hasDash || !hasSlash || !ok0 || !ok1 || !ok2 || m < n || (o >= 0 && o <= m) { - return start, length, completeLength, false - } - start = n - length = m - n + 1 - completeLength = o - } - return start, length, completeLength, hasUnits -} - -// FormatContentRange parses a "Content-Range" header per RFC 7233, section 4.2. -// It only handles "Content-Range" headers where the units is "bytes". -// The "Content-Range" header is usually only specified in HTTP responses. -// -// If start and length are non-positive, then it encodes just the completeLength, -// which must be a non-negative value. -// -// Otherwise, it encodes the start and length as a byte-range, -// and optionally emits the complete length if it is non-negative. -// The length must be positive (as RFC 7233 uses inclusive end offsets). -func FormatContentRange(start, length, completeLength int64) (hdr string, ok bool) { - b := []byte("bytes ") - switch { - case start <= 0 && length <= 0 && completeLength >= 0: // i.e., unsatisfied-range - b = append(b, "*/"...) - b = strconv.AppendUint(b, uint64(completeLength), 10) - ok = true - case start >= 0 && length > 0: // i.e., byte-range "/" (complete-length | "*") - b = strconv.AppendUint(b, uint64(start), 10) - b = append(b, '-') - b = strconv.AppendUint(b, uint64(start+length-1), 10) - b = append(b, '/') - if completeLength >= 0 { - b = strconv.AppendUint(b, uint64(completeLength), 10) - ok = completeLength >= start+length && start+length > 0 - } else { - b = append(b, '*') - ok = true - } - } - return string(b), ok -} - -// parseNumber parses s as an unsigned decimal integer. -// It parses according to the 1*DIGIT grammar, which allows leading zeros. -func parseNumber(s string) (int64, bool) { - suffix := strings.TrimLeft(s, "0123456789") - prefix := s[:len(s)-len(suffix)] - n, err := strconv.ParseInt(prefix, 10, 64) - return n, suffix == "" && err == nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package httphdr implements functionality for parsing and formatting +// standard HTTP headers. +package httphdr + +import ( + "bytes" + "strconv" + "strings" +) + +// Range is a range of bytes within some content. +type Range struct { + // Start is the starting offset. + // It is zero if Length is negative; it must not be negative. + Start int64 + // Length is the length of the content. + // It is zero if the length extends to the end of the content. + // It is negative if the length is relative to the end (e.g., last 5 bytes). + Length int64 +} + +// ows is optional whitespace. +const ows = " \t" // per RFC 7230, section 3.2.3 + +// ParseRange parses a "Range" header per RFC 7233, section 3. +// It only handles "Range" headers where the units is "bytes". +// The "Range" header is usually only specified in GET requests. +func ParseRange(hdr string) (ranges []Range, ok bool) { + // Grammar per RFC 7233, appendix D: + // Range = byte-ranges-specifier | other-ranges-specifier + // byte-ranges-specifier = bytes-unit "=" byte-range-set + // bytes-unit = "bytes" + // byte-range-set = + // *("," OWS) + // (byte-range-spec | suffix-byte-range-spec) + // *(OWS "," [OWS ( byte-range-spec | suffix-byte-range-spec )]) + // byte-range-spec = first-byte-pos "-" [last-byte-pos] + // suffix-byte-range-spec = "-" suffix-length + // We do not support other-ranges-specifier. + // All other identifiers are 1*DIGIT. + hdr = strings.Trim(hdr, ows) // per RFC 7230, section 3.2 + units, elems, hasUnits := strings.Cut(hdr, "=") + elems = strings.TrimLeft(elems, ","+ows) + for _, elem := range strings.Split(elems, ",") { + elem = strings.Trim(elem, ows) // per RFC 7230, section 7 + switch { + case strings.HasPrefix(elem, "-"): // i.e., "-" suffix-length + n, ok := parseNumber(strings.TrimPrefix(elem, "-")) + if !ok { + return ranges, false + } + ranges = append(ranges, Range{0, -n}) + case strings.HasSuffix(elem, "-"): // i.e., first-byte-pos "-" + n, ok := parseNumber(strings.TrimSuffix(elem, "-")) + if !ok { + return ranges, false + } + ranges = append(ranges, Range{n, 0}) + default: // i.e., first-byte-pos "-" last-byte-pos + prefix, suffix, hasDash := strings.Cut(elem, "-") + n, ok2 := parseNumber(prefix) + m, ok3 := parseNumber(suffix) + if !hasDash || !ok2 || !ok3 || m < n { + return ranges, false + } + ranges = append(ranges, Range{n, m - n + 1}) + } + } + return ranges, units == "bytes" && hasUnits && len(ranges) > 0 // must see at least one element per RFC 7233, section 2.1 +} + +// FormatRange formats a "Range" header per RFC 7233, section 3. +// It only handles "Range" headers where the units is "bytes". +// The "Range" header is usually only specified in GET requests. +func FormatRange(ranges []Range) (hdr string, ok bool) { + b := []byte("bytes=") + for _, r := range ranges { + switch { + case r.Length > 0: // i.e., first-byte-pos "-" last-byte-pos + if r.Start < 0 { + return string(b), false + } + b = strconv.AppendUint(b, uint64(r.Start), 10) + b = append(b, '-') + b = strconv.AppendUint(b, uint64(r.Start+r.Length-1), 10) + b = append(b, ',') + case r.Length == 0: // i.e., first-byte-pos "-" + if r.Start < 0 { + return string(b), false + } + b = strconv.AppendUint(b, uint64(r.Start), 10) + b = append(b, '-') + b = append(b, ',') + case r.Length < 0: // i.e., "-" suffix-length + if r.Start != 0 { + return string(b), false + } + b = append(b, '-') + b = strconv.AppendUint(b, uint64(-r.Length), 10) + b = append(b, ',') + default: + return string(b), false + } + } + return string(bytes.TrimRight(b, ",")), len(ranges) > 0 +} + +// ParseContentRange parses a "Content-Range" header per RFC 7233, section 4.2. +// It only handles "Content-Range" headers where the units is "bytes". +// The "Content-Range" header is usually only specified in HTTP responses. +// +// If only the completeLength is specified, then start and length are both zero. +// +// Otherwise, the parses the start and length and the optional completeLength, +// which is -1 if unspecified. The start is non-negative and the length is positive. +func ParseContentRange(hdr string) (start, length, completeLength int64, ok bool) { + // Grammar per RFC 7233, appendix D: + // Content-Range = byte-content-range | other-content-range + // byte-content-range = bytes-unit SP (byte-range-resp | unsatisfied-range) + // bytes-unit = "bytes" + // byte-range-resp = byte-range "/" (complete-length | "*") + // unsatisfied-range = "*/" complete-length + // byte-range = first-byte-pos "-" last-byte-pos + // We do not support other-content-range. + // All other identifiers are 1*DIGIT. + hdr = strings.Trim(hdr, ows) // per RFC 7230, section 3.2 + suffix, hasUnits := strings.CutPrefix(hdr, "bytes ") + suffix, unsatisfied := strings.CutPrefix(suffix, "*/") + if unsatisfied { // i.e., unsatisfied-range + n, ok := parseNumber(suffix) + if !ok { + return start, length, completeLength, false + } + completeLength = n + } else { // i.e., byte-range "/" (complete-length | "*") + prefix, suffix, hasDash := strings.Cut(suffix, "-") + middle, suffix, hasSlash := strings.Cut(suffix, "/") + n, ok0 := parseNumber(prefix) + m, ok1 := parseNumber(middle) + o, ok2 := parseNumber(suffix) + if suffix == "*" { + o, ok2 = -1, true + } + if !hasDash || !hasSlash || !ok0 || !ok1 || !ok2 || m < n || (o >= 0 && o <= m) { + return start, length, completeLength, false + } + start = n + length = m - n + 1 + completeLength = o + } + return start, length, completeLength, hasUnits +} + +// FormatContentRange parses a "Content-Range" header per RFC 7233, section 4.2. +// It only handles "Content-Range" headers where the units is "bytes". +// The "Content-Range" header is usually only specified in HTTP responses. +// +// If start and length are non-positive, then it encodes just the completeLength, +// which must be a non-negative value. +// +// Otherwise, it encodes the start and length as a byte-range, +// and optionally emits the complete length if it is non-negative. +// The length must be positive (as RFC 7233 uses inclusive end offsets). +func FormatContentRange(start, length, completeLength int64) (hdr string, ok bool) { + b := []byte("bytes ") + switch { + case start <= 0 && length <= 0 && completeLength >= 0: // i.e., unsatisfied-range + b = append(b, "*/"...) + b = strconv.AppendUint(b, uint64(completeLength), 10) + ok = true + case start >= 0 && length > 0: // i.e., byte-range "/" (complete-length | "*") + b = strconv.AppendUint(b, uint64(start), 10) + b = append(b, '-') + b = strconv.AppendUint(b, uint64(start+length-1), 10) + b = append(b, '/') + if completeLength >= 0 { + b = strconv.AppendUint(b, uint64(completeLength), 10) + ok = completeLength >= start+length && start+length > 0 + } else { + b = append(b, '*') + ok = true + } + } + return string(b), ok +} + +// parseNumber parses s as an unsigned decimal integer. +// It parses according to the 1*DIGIT grammar, which allows leading zeros. +func parseNumber(s string) (int64, bool) { + suffix := strings.TrimLeft(s, "0123456789") + prefix := s[:len(s)-len(suffix)] + n, err := strconv.ParseInt(prefix, 10, 64) + return n, suffix == "" && err == nil +} diff --git a/util/httphdr/httphdr_test.go b/util/httphdr/httphdr_test.go index 77ec0c324..81feeaca0 100644 --- a/util/httphdr/httphdr_test.go +++ b/util/httphdr/httphdr_test.go @@ -1,96 +1,96 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package httphdr - -import ( - "testing" - - "github.com/google/go-cmp/cmp" -) - -func valOk[T any](v T, ok bool) (out struct { - V T - Ok bool -}) { - out.V = v - out.Ok = ok - return out -} - -func TestRange(t *testing.T) { - tests := []struct { - in string - want []Range - wantOk bool - roundtrip bool - }{ - {"", nil, false, false}, - {"1-3", nil, false, false}, - {"units=1-3", []Range{{1, 3}}, false, false}, - {"bytes=1-3", []Range{{1, 3}}, true, true}, - {"bytes=#-3", nil, false, false}, - {"bytes=#-", nil, false, false}, - {"bytes=13", nil, false, false}, - {"bytes=1-#", nil, false, false}, - {"bytes=-#", nil, false, false}, - {"bytes= , , , ,\t , \t 1-3", []Range{{1, 3}}, true, false}, - {"bytes=1-1", []Range{{1, 1}}, true, true}, - {"bytes=01-01", []Range{{1, 1}}, true, false}, - {"bytes=1-0", nil, false, false}, - {"bytes=0-5,2-3", []Range{{0, 6}, {2, 2}}, true, true}, - {"bytes=2-3,0-5", []Range{{2, 2}, {0, 6}}, true, true}, - {"bytes=0-5,2-,-5", []Range{{0, 6}, {2, 0}, {0, -5}}, true, true}, - } - - for _, tt := range tests { - got, gotOk := ParseRange(tt.in) - if d := cmp.Diff(valOk(got, gotOk), valOk(tt.want, tt.wantOk)); d != "" { - t.Errorf("ParseRange(%q) mismatch (-got +want):\n%s", tt.in, d) - } - if tt.roundtrip { - got, gotOk := FormatRange(tt.want) - if d := cmp.Diff(valOk(got, gotOk), valOk(tt.in, tt.wantOk)); d != "" { - t.Errorf("FormatRange(%v) mismatch (-got +want):\n%s", tt.want, d) - } - } - } -} - -type contentRange struct{ Start, Length, CompleteLength int64 } - -func TestContentRange(t *testing.T) { - tests := []struct { - in string - want contentRange - wantOk bool - roundtrip bool - }{ - {"", contentRange{}, false, false}, - {"bytes 5-6/*", contentRange{5, 2, -1}, true, true}, - {"units 5-6/*", contentRange{}, false, false}, - {"bytes 5-6/*", contentRange{}, false, false}, - {"bytes 5-5/*", contentRange{5, 1, -1}, true, true}, - {"bytes 5-4/*", contentRange{}, false, false}, - {"bytes 5-5/6", contentRange{5, 1, 6}, true, true}, - {"bytes 05-005/0006", contentRange{5, 1, 6}, true, false}, - {"bytes 5-5/5", contentRange{}, false, false}, - {"bytes #-5/6", contentRange{}, false, false}, - {"bytes 5-#/6", contentRange{}, false, false}, - {"bytes 5-5/#", contentRange{}, false, false}, - } - - for _, tt := range tests { - start, length, completeLength, gotOk := ParseContentRange(tt.in) - got := contentRange{start, length, completeLength} - if d := cmp.Diff(valOk(got, gotOk), valOk(tt.want, tt.wantOk)); d != "" { - t.Errorf("ParseContentRange mismatch (-got +want):\n%s", d) - } - if tt.roundtrip { - got, gotOk := FormatContentRange(tt.want.Start, tt.want.Length, tt.want.CompleteLength) - if d := cmp.Diff(valOk(got, gotOk), valOk(tt.in, tt.wantOk)); d != "" { - t.Errorf("FormatContentRange mismatch (-got +want):\n%s", d) - } - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package httphdr + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +func valOk[T any](v T, ok bool) (out struct { + V T + Ok bool +}) { + out.V = v + out.Ok = ok + return out +} + +func TestRange(t *testing.T) { + tests := []struct { + in string + want []Range + wantOk bool + roundtrip bool + }{ + {"", nil, false, false}, + {"1-3", nil, false, false}, + {"units=1-3", []Range{{1, 3}}, false, false}, + {"bytes=1-3", []Range{{1, 3}}, true, true}, + {"bytes=#-3", nil, false, false}, + {"bytes=#-", nil, false, false}, + {"bytes=13", nil, false, false}, + {"bytes=1-#", nil, false, false}, + {"bytes=-#", nil, false, false}, + {"bytes= , , , ,\t , \t 1-3", []Range{{1, 3}}, true, false}, + {"bytes=1-1", []Range{{1, 1}}, true, true}, + {"bytes=01-01", []Range{{1, 1}}, true, false}, + {"bytes=1-0", nil, false, false}, + {"bytes=0-5,2-3", []Range{{0, 6}, {2, 2}}, true, true}, + {"bytes=2-3,0-5", []Range{{2, 2}, {0, 6}}, true, true}, + {"bytes=0-5,2-,-5", []Range{{0, 6}, {2, 0}, {0, -5}}, true, true}, + } + + for _, tt := range tests { + got, gotOk := ParseRange(tt.in) + if d := cmp.Diff(valOk(got, gotOk), valOk(tt.want, tt.wantOk)); d != "" { + t.Errorf("ParseRange(%q) mismatch (-got +want):\n%s", tt.in, d) + } + if tt.roundtrip { + got, gotOk := FormatRange(tt.want) + if d := cmp.Diff(valOk(got, gotOk), valOk(tt.in, tt.wantOk)); d != "" { + t.Errorf("FormatRange(%v) mismatch (-got +want):\n%s", tt.want, d) + } + } + } +} + +type contentRange struct{ Start, Length, CompleteLength int64 } + +func TestContentRange(t *testing.T) { + tests := []struct { + in string + want contentRange + wantOk bool + roundtrip bool + }{ + {"", contentRange{}, false, false}, + {"bytes 5-6/*", contentRange{5, 2, -1}, true, true}, + {"units 5-6/*", contentRange{}, false, false}, + {"bytes 5-6/*", contentRange{}, false, false}, + {"bytes 5-5/*", contentRange{5, 1, -1}, true, true}, + {"bytes 5-4/*", contentRange{}, false, false}, + {"bytes 5-5/6", contentRange{5, 1, 6}, true, true}, + {"bytes 05-005/0006", contentRange{5, 1, 6}, true, false}, + {"bytes 5-5/5", contentRange{}, false, false}, + {"bytes #-5/6", contentRange{}, false, false}, + {"bytes 5-#/6", contentRange{}, false, false}, + {"bytes 5-5/#", contentRange{}, false, false}, + } + + for _, tt := range tests { + start, length, completeLength, gotOk := ParseContentRange(tt.in) + got := contentRange{start, length, completeLength} + if d := cmp.Diff(valOk(got, gotOk), valOk(tt.want, tt.wantOk)); d != "" { + t.Errorf("ParseContentRange mismatch (-got +want):\n%s", d) + } + if tt.roundtrip { + got, gotOk := FormatContentRange(tt.want.Start, tt.want.Length, tt.want.CompleteLength) + if d := cmp.Diff(valOk(got, gotOk), valOk(tt.in, tt.wantOk)); d != "" { + t.Errorf("FormatContentRange mismatch (-got +want):\n%s", d) + } + } + } +} diff --git a/util/httpm/httpm.go b/util/httpm/httpm.go index 05292f0fa..a9a691b8a 100644 --- a/util/httpm/httpm.go +++ b/util/httpm/httpm.go @@ -1,36 +1,36 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package httpm has shorter names for HTTP method constants. -// -// Some background: originally Go didn't have http.MethodGet, http.MethodPost -// and life was good and people just wrote readable "GET" and "POST". But then -// in a moment of weakness Brad and others maintaining net/http caved and let -// the http.MethodFoo constants be added and code's been less readable since. -// Now the substance of the method name is hidden away at the end after -// "http.Method" and they all blend together and it's hard to read code using -// them. -// -// This package is a compromise. It provides constants, but shorter and closer -// to how it used to look. It does violate Go style -// (https://github.com/golang/go/wiki/CodeReviewComments#mixed-caps) that says -// constants shouldn't be SCREAM_CASE. But this isn't INT_MAX; it's GET and -// POST, which are already defined as all caps. -// -// It would be tempting to make these constants be typed but then they wouldn't -// be assignable to things in net/http that just want string. Oh well. -package httpm - -const ( - GET = "GET" - HEAD = "HEAD" - POST = "POST" - PUT = "PUT" - PATCH = "PATCH" - DELETE = "DELETE" - CONNECT = "CONNECT" - OPTIONS = "OPTIONS" - TRACE = "TRACE" - SPACEJUMP = "SPACEJUMP" // https://www.w3.org/Protocols/HTTP/Methods/SpaceJump.html - BREW = "BREW" // https://datatracker.ietf.org/doc/html/rfc2324#section-2.1.1 -) +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package httpm has shorter names for HTTP method constants. +// +// Some background: originally Go didn't have http.MethodGet, http.MethodPost +// and life was good and people just wrote readable "GET" and "POST". But then +// in a moment of weakness Brad and others maintaining net/http caved and let +// the http.MethodFoo constants be added and code's been less readable since. +// Now the substance of the method name is hidden away at the end after +// "http.Method" and they all blend together and it's hard to read code using +// them. +// +// This package is a compromise. It provides constants, but shorter and closer +// to how it used to look. It does violate Go style +// (https://github.com/golang/go/wiki/CodeReviewComments#mixed-caps) that says +// constants shouldn't be SCREAM_CASE. But this isn't INT_MAX; it's GET and +// POST, which are already defined as all caps. +// +// It would be tempting to make these constants be typed but then they wouldn't +// be assignable to things in net/http that just want string. Oh well. +package httpm + +const ( + GET = "GET" + HEAD = "HEAD" + POST = "POST" + PUT = "PUT" + PATCH = "PATCH" + DELETE = "DELETE" + CONNECT = "CONNECT" + OPTIONS = "OPTIONS" + TRACE = "TRACE" + SPACEJUMP = "SPACEJUMP" // https://www.w3.org/Protocols/HTTP/Methods/SpaceJump.html + BREW = "BREW" // https://datatracker.ietf.org/doc/html/rfc2324#section-2.1.1 +) diff --git a/util/httpm/httpm_test.go b/util/httpm/httpm_test.go index cbe327d95..0c71edc2f 100644 --- a/util/httpm/httpm_test.go +++ b/util/httpm/httpm_test.go @@ -1,37 +1,37 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package httpm - -import ( - "os" - "os/exec" - "path/filepath" - "strings" - "testing" -) - -func TestUsedConsistently(t *testing.T) { - dir, err := os.Getwd() - if err != nil { - t.Fatal(err) - } - rootDir := filepath.Join(dir, "../..") - - // If we don't have a .git directory, we're not in a git checkout (e.g. - // a downstream package); skip this test. - if _, err := os.Stat(filepath.Join(rootDir, ".git")); err != nil { - t.Skipf("skipping test since .git doesn't exist: %v", err) - } - - cmd := exec.Command("git", "grep", "-l", "-F", "http.Method") - cmd.Dir = rootDir - matches, _ := cmd.Output() - for _, fn := range strings.Split(strings.TrimSpace(string(matches)), "\n") { - switch fn { - case "util/httpm/httpm.go", "util/httpm/httpm_test.go": - continue - } - t.Errorf("http.MethodFoo constant used in %s; use httpm.FOO instead", fn) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package httpm + +import ( + "os" + "os/exec" + "path/filepath" + "strings" + "testing" +) + +func TestUsedConsistently(t *testing.T) { + dir, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + rootDir := filepath.Join(dir, "../..") + + // If we don't have a .git directory, we're not in a git checkout (e.g. + // a downstream package); skip this test. + if _, err := os.Stat(filepath.Join(rootDir, ".git")); err != nil { + t.Skipf("skipping test since .git doesn't exist: %v", err) + } + + cmd := exec.Command("git", "grep", "-l", "-F", "http.Method") + cmd.Dir = rootDir + matches, _ := cmd.Output() + for _, fn := range strings.Split(strings.TrimSpace(string(matches)), "\n") { + switch fn { + case "util/httpm/httpm.go", "util/httpm/httpm_test.go": + continue + } + t.Errorf("http.MethodFoo constant used in %s; use httpm.FOO instead", fn) + } +} diff --git a/util/jsonutil/types.go b/util/jsonutil/types.go index 2ee53f44a..057473249 100644 --- a/util/jsonutil/types.go +++ b/util/jsonutil/types.go @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package jsonutil - -// Bytes is a byte slice in a json-encoded struct. -// encoding/json assumes that []byte fields are hex-encoded. -// Bytes are not hex-encoded; they are treated the same as strings. -// This can avoid unnecessary allocations due to a round trip through strings. -type Bytes []byte - -func (b *Bytes) UnmarshalText(text []byte) error { - // Copy the contexts of text. - *b = append(*b, text...) - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package jsonutil + +// Bytes is a byte slice in a json-encoded struct. +// encoding/json assumes that []byte fields are hex-encoded. +// Bytes are not hex-encoded; they are treated the same as strings. +// This can avoid unnecessary allocations due to a round trip through strings. +type Bytes []byte + +func (b *Bytes) UnmarshalText(text []byte) error { + // Copy the contexts of text. + *b = append(*b, text...) + return nil +} diff --git a/util/jsonutil/unmarshal.go b/util/jsonutil/unmarshal.go index 13aea0c87..b1eb4ea87 100644 --- a/util/jsonutil/unmarshal.go +++ b/util/jsonutil/unmarshal.go @@ -1,89 +1,89 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package jsonutil provides utilities to improve JSON performance. -// It includes an Unmarshal wrapper that amortizes allocated garbage over subsequent runs -// and a Bytes type to reduce allocations when unmarshalling a non-hex-encoded string into a []byte. -package jsonutil - -import ( - "bytes" - "encoding/json" - "sync" -) - -// decoder is a re-usable json decoder. -type decoder struct { - dec *json.Decoder - r *bytes.Reader -} - -var readerPool = sync.Pool{ - New: func() any { - return bytes.NewReader(nil) - }, -} - -var decoderPool = sync.Pool{ - New: func() any { - var d decoder - d.r = readerPool.Get().(*bytes.Reader) - d.dec = json.NewDecoder(d.r) - return &d - }, -} - -// Unmarshal is similar to encoding/json.Unmarshal. -// There are three major differences: -// -// On error, encoding/json.Unmarshal zeros v. -// This Unmarshal may leave partial data in v. -// Always check the error before using v! -// (Future improvements may remove this bug.) -// -// The errors they return don't always match perfectly. -// If you do error matching more precise than err != nil, -// don't use this Unmarshal. -// -// This Unmarshal allocates considerably less memory. -func Unmarshal(b []byte, v any) error { - d := decoderPool.Get().(*decoder) - d.r.Reset(b) - off := d.dec.InputOffset() - err := d.dec.Decode(v) - d.r.Reset(nil) // don't keep a reference to b - // In case of error, report the offset in this byte slice, - // instead of in the totality of all bytes this decoder has processed. - // It is not possible to make all errors match json.Unmarshal exactly, - // but we can at least try. - switch jsonerr := err.(type) { - case *json.SyntaxError: - jsonerr.Offset -= off - case *json.UnmarshalTypeError: - jsonerr.Offset -= off - case nil: - // json.Unmarshal fails if there's any extra junk in the input. - // json.Decoder does not; see https://github.com/golang/go/issues/36225. - // We need to check for anything left over in the buffer. - if d.dec.More() { - // TODO: Provide a better error message. - // Unfortunately, we can't set the msg field. - // The offset doesn't perfectly match json: - // Ours is at the end of the valid data, - // and theirs is at the beginning of the extra data after whitespace. - // Close enough, though. - err = &json.SyntaxError{Offset: d.dec.InputOffset() - off} - - // TODO: zero v. This is hard; see encoding/json.indirect. - } - } - if err == nil { - decoderPool.Put(d) - } else { - // There might be junk left in the decoder's buffer. - // There's no way to flush it, no Reset method. - // Abandoned the decoder but reuse the reader. - readerPool.Put(d.r) - } - return err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package jsonutil provides utilities to improve JSON performance. +// It includes an Unmarshal wrapper that amortizes allocated garbage over subsequent runs +// and a Bytes type to reduce allocations when unmarshalling a non-hex-encoded string into a []byte. +package jsonutil + +import ( + "bytes" + "encoding/json" + "sync" +) + +// decoder is a re-usable json decoder. +type decoder struct { + dec *json.Decoder + r *bytes.Reader +} + +var readerPool = sync.Pool{ + New: func() any { + return bytes.NewReader(nil) + }, +} + +var decoderPool = sync.Pool{ + New: func() any { + var d decoder + d.r = readerPool.Get().(*bytes.Reader) + d.dec = json.NewDecoder(d.r) + return &d + }, +} + +// Unmarshal is similar to encoding/json.Unmarshal. +// There are three major differences: +// +// On error, encoding/json.Unmarshal zeros v. +// This Unmarshal may leave partial data in v. +// Always check the error before using v! +// (Future improvements may remove this bug.) +// +// The errors they return don't always match perfectly. +// If you do error matching more precise than err != nil, +// don't use this Unmarshal. +// +// This Unmarshal allocates considerably less memory. +func Unmarshal(b []byte, v any) error { + d := decoderPool.Get().(*decoder) + d.r.Reset(b) + off := d.dec.InputOffset() + err := d.dec.Decode(v) + d.r.Reset(nil) // don't keep a reference to b + // In case of error, report the offset in this byte slice, + // instead of in the totality of all bytes this decoder has processed. + // It is not possible to make all errors match json.Unmarshal exactly, + // but we can at least try. + switch jsonerr := err.(type) { + case *json.SyntaxError: + jsonerr.Offset -= off + case *json.UnmarshalTypeError: + jsonerr.Offset -= off + case nil: + // json.Unmarshal fails if there's any extra junk in the input. + // json.Decoder does not; see https://github.com/golang/go/issues/36225. + // We need to check for anything left over in the buffer. + if d.dec.More() { + // TODO: Provide a better error message. + // Unfortunately, we can't set the msg field. + // The offset doesn't perfectly match json: + // Ours is at the end of the valid data, + // and theirs is at the beginning of the extra data after whitespace. + // Close enough, though. + err = &json.SyntaxError{Offset: d.dec.InputOffset() - off} + + // TODO: zero v. This is hard; see encoding/json.indirect. + } + } + if err == nil { + decoderPool.Put(d) + } else { + // There might be junk left in the decoder's buffer. + // There's no way to flush it, no Reset method. + // Abandoned the decoder but reuse the reader. + readerPool.Put(d.r) + } + return err +} diff --git a/util/lineread/lineread.go b/util/lineread/lineread.go index 2a7486e0a..6b01d2b69 100644 --- a/util/lineread/lineread.go +++ b/util/lineread/lineread.go @@ -1,37 +1,37 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package lineread reads lines from files. It's not fancy, but it got repetitive. -package lineread - -import ( - "bufio" - "io" - "os" -) - -// File opens name and calls fn for each line. It returns an error if the Open failed -// or once fn returns an error. -func File(name string, fn func(line []byte) error) error { - f, err := os.Open(name) - if err != nil { - return err - } - defer f.Close() - return Reader(f, fn) -} - -// Reader calls fn for each line. -// If fn returns an error, Reader stops reading and returns that error. -// Reader may also return errors encountered reading and parsing from r. -// To stop reading early, use a sentinel "stop" error value and ignore -// it when returned from Reader. -func Reader(r io.Reader, fn func(line []byte) error) error { - bs := bufio.NewScanner(r) - for bs.Scan() { - if err := fn(bs.Bytes()); err != nil { - return err - } - } - return bs.Err() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package lineread reads lines from files. It's not fancy, but it got repetitive. +package lineread + +import ( + "bufio" + "io" + "os" +) + +// File opens name and calls fn for each line. It returns an error if the Open failed +// or once fn returns an error. +func File(name string, fn func(line []byte) error) error { + f, err := os.Open(name) + if err != nil { + return err + } + defer f.Close() + return Reader(f, fn) +} + +// Reader calls fn for each line. +// If fn returns an error, Reader stops reading and returns that error. +// Reader may also return errors encountered reading and parsing from r. +// To stop reading early, use a sentinel "stop" error value and ignore +// it when returned from Reader. +func Reader(r io.Reader, fn func(line []byte) error) error { + bs := bufio.NewScanner(r) + for bs.Scan() { + if err := fn(bs.Bytes()); err != nil { + return err + } + } + return bs.Err() +} diff --git a/util/linuxfw/linuxfwtest/linuxfwtest.go b/util/linuxfw/linuxfwtest/linuxfwtest.go index 04f179199..ee2cbd1b2 100644 --- a/util/linuxfw/linuxfwtest/linuxfwtest.go +++ b/util/linuxfw/linuxfwtest/linuxfwtest.go @@ -1,31 +1,31 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build cgo && linux - -// Package linuxfwtest contains tests for the linuxfw package. Go does not -// support cgo in tests, and we don't want the main package to have a cgo -// dependency, so we put all the tests here and call them from the main package -// in tests intead. -package linuxfwtest - -import ( - "testing" - "unsafe" -) - -/* -#include // socket() -*/ -import "C" - -type SizeInfo struct { - SizeofSocklen uintptr -} - -func TestSizes(t *testing.T, si *SizeInfo) { - want := unsafe.Sizeof(C.socklen_t(0)) - if want != si.SizeofSocklen { - t.Errorf("sockLen has wrong size; want=%d got=%d", want, si.SizeofSocklen) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build cgo && linux + +// Package linuxfwtest contains tests for the linuxfw package. Go does not +// support cgo in tests, and we don't want the main package to have a cgo +// dependency, so we put all the tests here and call them from the main package +// in tests intead. +package linuxfwtest + +import ( + "testing" + "unsafe" +) + +/* +#include // socket() +*/ +import "C" + +type SizeInfo struct { + SizeofSocklen uintptr +} + +func TestSizes(t *testing.T, si *SizeInfo) { + want := unsafe.Sizeof(C.socklen_t(0)) + if want != si.SizeofSocklen { + t.Errorf("sockLen has wrong size; want=%d got=%d", want, si.SizeofSocklen) + } +} diff --git a/util/linuxfw/linuxfwtest/linuxfwtest_unsupported.go b/util/linuxfw/linuxfwtest/linuxfwtest_unsupported.go index d5e297da7..6e9569900 100644 --- a/util/linuxfw/linuxfwtest/linuxfwtest_unsupported.go +++ b/util/linuxfw/linuxfwtest/linuxfwtest_unsupported.go @@ -1,18 +1,18 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !cgo || !linux - -package linuxfwtest - -import ( - "testing" -) - -type SizeInfo struct { - SizeofSocklen uintptr -} - -func TestSizes(t *testing.T, si *SizeInfo) { - t.Skip("not supported without cgo") -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !cgo || !linux + +package linuxfwtest + +import ( + "testing" +) + +type SizeInfo struct { + SizeofSocklen uintptr +} + +func TestSizes(t *testing.T, si *SizeInfo) { + t.Skip("not supported without cgo") +} diff --git a/util/linuxfw/nftables_types.go b/util/linuxfw/nftables_types.go index a8c5a0730..b6e24d2a6 100644 --- a/util/linuxfw/nftables_types.go +++ b/util/linuxfw/nftables_types.go @@ -1,95 +1,95 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// TODO(#8502): add support for more architectures -//go:build linux && (arm64 || amd64) - -package linuxfw - -import ( - "github.com/google/nftables/expr" - "github.com/google/nftables/xt" -) - -var metaKeyNames = map[expr.MetaKey]string{ - expr.MetaKeyLEN: "LEN", - expr.MetaKeyPROTOCOL: "PROTOCOL", - expr.MetaKeyPRIORITY: "PRIORITY", - expr.MetaKeyMARK: "MARK", - expr.MetaKeyIIF: "IIF", - expr.MetaKeyOIF: "OIF", - expr.MetaKeyIIFNAME: "IIFNAME", - expr.MetaKeyOIFNAME: "OIFNAME", - expr.MetaKeyIIFTYPE: "IIFTYPE", - expr.MetaKeyOIFTYPE: "OIFTYPE", - expr.MetaKeySKUID: "SKUID", - expr.MetaKeySKGID: "SKGID", - expr.MetaKeyNFTRACE: "NFTRACE", - expr.MetaKeyRTCLASSID: "RTCLASSID", - expr.MetaKeySECMARK: "SECMARK", - expr.MetaKeyNFPROTO: "NFPROTO", - expr.MetaKeyL4PROTO: "L4PROTO", - expr.MetaKeyBRIIIFNAME: "BRIIIFNAME", - expr.MetaKeyBRIOIFNAME: "BRIOIFNAME", - expr.MetaKeyPKTTYPE: "PKTTYPE", - expr.MetaKeyCPU: "CPU", - expr.MetaKeyIIFGROUP: "IIFGROUP", - expr.MetaKeyOIFGROUP: "OIFGROUP", - expr.MetaKeyCGROUP: "CGROUP", - expr.MetaKeyPRANDOM: "PRANDOM", -} - -var cmpOpNames = map[expr.CmpOp]string{ - expr.CmpOpEq: "EQ", - expr.CmpOpNeq: "NEQ", - expr.CmpOpLt: "LT", - expr.CmpOpLte: "LTE", - expr.CmpOpGt: "GT", - expr.CmpOpGte: "GTE", -} - -var verdictNames = map[expr.VerdictKind]string{ - expr.VerdictReturn: "RETURN", - expr.VerdictGoto: "GOTO", - expr.VerdictJump: "JUMP", - expr.VerdictBreak: "BREAK", - expr.VerdictContinue: "CONTINUE", - expr.VerdictDrop: "DROP", - expr.VerdictAccept: "ACCEPT", - expr.VerdictStolen: "STOLEN", - expr.VerdictQueue: "QUEUE", - expr.VerdictRepeat: "REPEAT", - expr.VerdictStop: "STOP", -} - -var payloadOperationTypeNames = map[expr.PayloadOperationType]string{ - expr.PayloadLoad: "LOAD", - expr.PayloadWrite: "WRITE", -} - -var payloadBaseNames = map[expr.PayloadBase]string{ - expr.PayloadBaseLLHeader: "ll-header", - expr.PayloadBaseNetworkHeader: "network-header", - expr.PayloadBaseTransportHeader: "transport-header", -} - -var packetTypeNames = map[int]string{ - 0 /* PACKET_HOST */ : "unicast", - 1 /* PACKET_BROADCAST */ : "broadcast", - 2 /* PACKET_MULTICAST */ : "multicast", -} - -var addrTypeFlagNames = map[xt.AddrTypeFlags]string{ - xt.AddrTypeUnspec: "unspec", - xt.AddrTypeUnicast: "unicast", - xt.AddrTypeLocal: "local", - xt.AddrTypeBroadcast: "broadcast", - xt.AddrTypeAnycast: "anycast", - xt.AddrTypeMulticast: "multicast", - xt.AddrTypeBlackhole: "blackhole", - xt.AddrTypeUnreachable: "unreachable", - xt.AddrTypeProhibit: "prohibit", - xt.AddrTypeThrow: "throw", - xt.AddrTypeNat: "nat", - xt.AddrTypeXresolve: "xresolve", -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// TODO(#8502): add support for more architectures +//go:build linux && (arm64 || amd64) + +package linuxfw + +import ( + "github.com/google/nftables/expr" + "github.com/google/nftables/xt" +) + +var metaKeyNames = map[expr.MetaKey]string{ + expr.MetaKeyLEN: "LEN", + expr.MetaKeyPROTOCOL: "PROTOCOL", + expr.MetaKeyPRIORITY: "PRIORITY", + expr.MetaKeyMARK: "MARK", + expr.MetaKeyIIF: "IIF", + expr.MetaKeyOIF: "OIF", + expr.MetaKeyIIFNAME: "IIFNAME", + expr.MetaKeyOIFNAME: "OIFNAME", + expr.MetaKeyIIFTYPE: "IIFTYPE", + expr.MetaKeyOIFTYPE: "OIFTYPE", + expr.MetaKeySKUID: "SKUID", + expr.MetaKeySKGID: "SKGID", + expr.MetaKeyNFTRACE: "NFTRACE", + expr.MetaKeyRTCLASSID: "RTCLASSID", + expr.MetaKeySECMARK: "SECMARK", + expr.MetaKeyNFPROTO: "NFPROTO", + expr.MetaKeyL4PROTO: "L4PROTO", + expr.MetaKeyBRIIIFNAME: "BRIIIFNAME", + expr.MetaKeyBRIOIFNAME: "BRIOIFNAME", + expr.MetaKeyPKTTYPE: "PKTTYPE", + expr.MetaKeyCPU: "CPU", + expr.MetaKeyIIFGROUP: "IIFGROUP", + expr.MetaKeyOIFGROUP: "OIFGROUP", + expr.MetaKeyCGROUP: "CGROUP", + expr.MetaKeyPRANDOM: "PRANDOM", +} + +var cmpOpNames = map[expr.CmpOp]string{ + expr.CmpOpEq: "EQ", + expr.CmpOpNeq: "NEQ", + expr.CmpOpLt: "LT", + expr.CmpOpLte: "LTE", + expr.CmpOpGt: "GT", + expr.CmpOpGte: "GTE", +} + +var verdictNames = map[expr.VerdictKind]string{ + expr.VerdictReturn: "RETURN", + expr.VerdictGoto: "GOTO", + expr.VerdictJump: "JUMP", + expr.VerdictBreak: "BREAK", + expr.VerdictContinue: "CONTINUE", + expr.VerdictDrop: "DROP", + expr.VerdictAccept: "ACCEPT", + expr.VerdictStolen: "STOLEN", + expr.VerdictQueue: "QUEUE", + expr.VerdictRepeat: "REPEAT", + expr.VerdictStop: "STOP", +} + +var payloadOperationTypeNames = map[expr.PayloadOperationType]string{ + expr.PayloadLoad: "LOAD", + expr.PayloadWrite: "WRITE", +} + +var payloadBaseNames = map[expr.PayloadBase]string{ + expr.PayloadBaseLLHeader: "ll-header", + expr.PayloadBaseNetworkHeader: "network-header", + expr.PayloadBaseTransportHeader: "transport-header", +} + +var packetTypeNames = map[int]string{ + 0 /* PACKET_HOST */ : "unicast", + 1 /* PACKET_BROADCAST */ : "broadcast", + 2 /* PACKET_MULTICAST */ : "multicast", +} + +var addrTypeFlagNames = map[xt.AddrTypeFlags]string{ + xt.AddrTypeUnspec: "unspec", + xt.AddrTypeUnicast: "unicast", + xt.AddrTypeLocal: "local", + xt.AddrTypeBroadcast: "broadcast", + xt.AddrTypeAnycast: "anycast", + xt.AddrTypeMulticast: "multicast", + xt.AddrTypeBlackhole: "blackhole", + xt.AddrTypeUnreachable: "unreachable", + xt.AddrTypeProhibit: "prohibit", + xt.AddrTypeThrow: "throw", + xt.AddrTypeNat: "nat", + xt.AddrTypeXresolve: "xresolve", +} diff --git a/util/mak/mak.go b/util/mak/mak.go index b0d64daa4..b421fb0ed 100644 --- a/util/mak/mak.go +++ b/util/mak/mak.go @@ -1,70 +1,70 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package mak helps make maps. It contains generic helpers to make/assign -// things, notably to maps, but also slices. -package mak - -import ( - "fmt" - "reflect" -) - -// Set populates an entry in a map, making the map if necessary. -// -// That is, it assigns (*m)[k] = v, making *m if it was nil. -func Set[K comparable, V any, T ~map[K]V](m *T, k K, v V) { - if *m == nil { - *m = make(map[K]V) - } - (*m)[k] = v -} - -// NonNil takes a pointer to a Go data structure -// (currently only a slice or a map) and makes sure it's non-nil for -// JSON serialization. (In particular, JavaScript clients usually want -// the field to be defined after they decode the JSON.) -// -// Deprecated: use NonNilSliceForJSON or NonNilMapForJSON instead. -func NonNil(ptr any) { - if ptr == nil { - panic("nil interface") - } - rv := reflect.ValueOf(ptr) - if rv.Kind() != reflect.Ptr { - panic(fmt.Sprintf("kind %v, not Ptr", rv.Kind())) - } - if rv.Pointer() == 0 { - panic("nil pointer") - } - rv = rv.Elem() - if rv.Pointer() != 0 { - return - } - switch rv.Type().Kind() { - case reflect.Slice: - rv.Set(reflect.MakeSlice(rv.Type(), 0, 0)) - case reflect.Map: - rv.Set(reflect.MakeMap(rv.Type())) - } -} - -// NonNilSliceForJSON makes sure that *slicePtr is non-nil so it will -// won't be omitted from JSON serialization and possibly confuse JavaScript -// clients expecting it to be present. -func NonNilSliceForJSON[T any, S ~[]T](slicePtr *S) { - if *slicePtr != nil { - return - } - *slicePtr = make([]T, 0) -} - -// NonNilMapForJSON makes sure that *slicePtr is non-nil so it will -// won't be omitted from JSON serialization and possibly confuse JavaScript -// clients expecting it to be present. -func NonNilMapForJSON[K comparable, V any, M ~map[K]V](mapPtr *M) { - if *mapPtr != nil { - return - } - *mapPtr = make(M) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package mak helps make maps. It contains generic helpers to make/assign +// things, notably to maps, but also slices. +package mak + +import ( + "fmt" + "reflect" +) + +// Set populates an entry in a map, making the map if necessary. +// +// That is, it assigns (*m)[k] = v, making *m if it was nil. +func Set[K comparable, V any, T ~map[K]V](m *T, k K, v V) { + if *m == nil { + *m = make(map[K]V) + } + (*m)[k] = v +} + +// NonNil takes a pointer to a Go data structure +// (currently only a slice or a map) and makes sure it's non-nil for +// JSON serialization. (In particular, JavaScript clients usually want +// the field to be defined after they decode the JSON.) +// +// Deprecated: use NonNilSliceForJSON or NonNilMapForJSON instead. +func NonNil(ptr any) { + if ptr == nil { + panic("nil interface") + } + rv := reflect.ValueOf(ptr) + if rv.Kind() != reflect.Ptr { + panic(fmt.Sprintf("kind %v, not Ptr", rv.Kind())) + } + if rv.Pointer() == 0 { + panic("nil pointer") + } + rv = rv.Elem() + if rv.Pointer() != 0 { + return + } + switch rv.Type().Kind() { + case reflect.Slice: + rv.Set(reflect.MakeSlice(rv.Type(), 0, 0)) + case reflect.Map: + rv.Set(reflect.MakeMap(rv.Type())) + } +} + +// NonNilSliceForJSON makes sure that *slicePtr is non-nil so it will +// won't be omitted from JSON serialization and possibly confuse JavaScript +// clients expecting it to be present. +func NonNilSliceForJSON[T any, S ~[]T](slicePtr *S) { + if *slicePtr != nil { + return + } + *slicePtr = make([]T, 0) +} + +// NonNilMapForJSON makes sure that *slicePtr is non-nil so it will +// won't be omitted from JSON serialization and possibly confuse JavaScript +// clients expecting it to be present. +func NonNilMapForJSON[K comparable, V any, M ~map[K]V](mapPtr *M) { + if *mapPtr != nil { + return + } + *mapPtr = make(M) +} diff --git a/util/mak/mak_test.go b/util/mak/mak_test.go index dc1d7e93d..4de499a9d 100644 --- a/util/mak/mak_test.go +++ b/util/mak/mak_test.go @@ -1,88 +1,88 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package mak contains code to help make things. -package mak - -import ( - "reflect" - "testing" -) - -type M map[string]int - -func TestSet(t *testing.T) { - t.Run("unnamed", func(t *testing.T) { - var m map[string]int - Set(&m, "foo", 42) - Set(&m, "bar", 1) - Set(&m, "bar", 2) - want := map[string]int{ - "foo": 42, - "bar": 2, - } - if got := m; !reflect.DeepEqual(got, want) { - t.Errorf("got %v; want %v", got, want) - } - }) - t.Run("named", func(t *testing.T) { - var m M - Set(&m, "foo", 1) - Set(&m, "bar", 1) - Set(&m, "bar", 2) - want := M{ - "foo": 1, - "bar": 2, - } - if got := m; !reflect.DeepEqual(got, want) { - t.Errorf("got %v; want %v", got, want) - } - }) -} - -func TestNonNil(t *testing.T) { - var s []string - NonNil(&s) - if len(s) != 0 { - t.Errorf("slice len = %d; want 0", len(s)) - } - if s == nil { - t.Error("slice still nil") - } - - s = append(s, "foo") - NonNil(&s) - if len(s) != 1 { - t.Errorf("len = %d; want 1", len(s)) - } - if s[0] != "foo" { - t.Errorf("value = %q; want foo", s) - } - - var m map[string]string - NonNil(&m) - if len(m) != 0 { - t.Errorf("map len = %d; want 0", len(s)) - } - if m == nil { - t.Error("map still nil") - } -} - -func TestNonNilMapForJSON(t *testing.T) { - type M map[string]int - var m M - NonNilMapForJSON(&m) - if m == nil { - t.Fatal("still nil") - } -} - -func TestNonNilSliceForJSON(t *testing.T) { - type S []int - var s S - NonNilSliceForJSON(&s) - if s == nil { - t.Fatal("still nil") - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package mak contains code to help make things. +package mak + +import ( + "reflect" + "testing" +) + +type M map[string]int + +func TestSet(t *testing.T) { + t.Run("unnamed", func(t *testing.T) { + var m map[string]int + Set(&m, "foo", 42) + Set(&m, "bar", 1) + Set(&m, "bar", 2) + want := map[string]int{ + "foo": 42, + "bar": 2, + } + if got := m; !reflect.DeepEqual(got, want) { + t.Errorf("got %v; want %v", got, want) + } + }) + t.Run("named", func(t *testing.T) { + var m M + Set(&m, "foo", 1) + Set(&m, "bar", 1) + Set(&m, "bar", 2) + want := M{ + "foo": 1, + "bar": 2, + } + if got := m; !reflect.DeepEqual(got, want) { + t.Errorf("got %v; want %v", got, want) + } + }) +} + +func TestNonNil(t *testing.T) { + var s []string + NonNil(&s) + if len(s) != 0 { + t.Errorf("slice len = %d; want 0", len(s)) + } + if s == nil { + t.Error("slice still nil") + } + + s = append(s, "foo") + NonNil(&s) + if len(s) != 1 { + t.Errorf("len = %d; want 1", len(s)) + } + if s[0] != "foo" { + t.Errorf("value = %q; want foo", s) + } + + var m map[string]string + NonNil(&m) + if len(m) != 0 { + t.Errorf("map len = %d; want 0", len(s)) + } + if m == nil { + t.Error("map still nil") + } +} + +func TestNonNilMapForJSON(t *testing.T) { + type M map[string]int + var m M + NonNilMapForJSON(&m) + if m == nil { + t.Fatal("still nil") + } +} + +func TestNonNilSliceForJSON(t *testing.T) { + type S []int + var s S + NonNilSliceForJSON(&s) + if s == nil { + t.Fatal("still nil") + } +} diff --git a/util/multierr/multierr.go b/util/multierr/multierr.go index 5ec36f644..93ca068f5 100644 --- a/util/multierr/multierr.go +++ b/util/multierr/multierr.go @@ -1,136 +1,136 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package multierr provides a simple multiple-error type. -// It was inspired by github.com/go-multierror/multierror. -package multierr - -import ( - "errors" - "slices" - "strings" -) - -// An Error represents multiple errors. -type Error struct { - errs []error -} - -// Error implements the error interface. -func (e Error) Error() string { - s := new(strings.Builder) - s.WriteString("multiple errors:") - for _, err := range e.errs { - s.WriteString("\n\t") - s.WriteString(err.Error()) - } - return s.String() -} - -// Errors returns a slice containing all errors in e. -func (e Error) Errors() []error { - return slices.Clone(e.errs) -} - -// Unwrap returns the underlying errors as-is. -func (e Error) Unwrap() []error { - // Do not clone since Unwrap requires callers to not mutate the slice. - // See the documentation in the Go "errors" package. - return e.errs -} - -// New returns an error composed from errs. -// Some errors in errs get special treatment: -// - nil errors are discarded -// - errors of type Error are expanded into the top level -// -// If the resulting slice has length 0, New returns nil. -// If the resulting slice has length 1, New returns that error. -// If the resulting slice has length > 1, New returns that slice as an Error. -func New(errs ...error) error { - // First count the number of errors to avoid allocating. - var n int - var errFirst error - for _, e := range errs { - switch e := e.(type) { - case nil: - continue - case Error: - n += len(e.errs) - if errFirst == nil && len(e.errs) > 0 { - errFirst = e.errs[0] - } - default: - n++ - if errFirst == nil { - errFirst = e - } - } - } - if n <= 1 { - return errFirst // nil if n == 0 - } - - // More than one error, allocate slice and construct the multi-error. - dst := make([]error, 0, n) - for _, e := range errs { - switch e := e.(type) { - case nil: - continue - case Error: - dst = append(dst, e.errs...) - default: - dst = append(dst, e) - } - } - return Error{errs: dst} -} - -// Is reports whether any error in e matches target. -func (e Error) Is(target error) bool { - for _, err := range e.errs { - if errors.Is(err, target) { - return true - } - } - return false -} - -// As finds the first error in e that matches target, and if any is found, -// sets target to that error value and returns true. Otherwise, it returns false. -func (e Error) As(target any) bool { - for _, err := range e.errs { - if ok := errors.As(err, target); ok { - return true - } - } - return false -} - -// Range performs a pre-order, depth-first iteration of the error tree -// by successively unwrapping all error values. -// For each iteration it calls fn with the current error value and -// stops iteration if it ever reports false. -func Range(err error, fn func(error) bool) bool { - if err == nil { - return true - } - if !fn(err) { - return false - } - switch err := err.(type) { - case interface{ Unwrap() error }: - if err := err.Unwrap(); err != nil { - if !Range(err, fn) { - return false - } - } - case interface{ Unwrap() []error }: - for _, err := range err.Unwrap() { - if !Range(err, fn) { - return false - } - } - } - return true -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package multierr provides a simple multiple-error type. +// It was inspired by github.com/go-multierror/multierror. +package multierr + +import ( + "errors" + "slices" + "strings" +) + +// An Error represents multiple errors. +type Error struct { + errs []error +} + +// Error implements the error interface. +func (e Error) Error() string { + s := new(strings.Builder) + s.WriteString("multiple errors:") + for _, err := range e.errs { + s.WriteString("\n\t") + s.WriteString(err.Error()) + } + return s.String() +} + +// Errors returns a slice containing all errors in e. +func (e Error) Errors() []error { + return slices.Clone(e.errs) +} + +// Unwrap returns the underlying errors as-is. +func (e Error) Unwrap() []error { + // Do not clone since Unwrap requires callers to not mutate the slice. + // See the documentation in the Go "errors" package. + return e.errs +} + +// New returns an error composed from errs. +// Some errors in errs get special treatment: +// - nil errors are discarded +// - errors of type Error are expanded into the top level +// +// If the resulting slice has length 0, New returns nil. +// If the resulting slice has length 1, New returns that error. +// If the resulting slice has length > 1, New returns that slice as an Error. +func New(errs ...error) error { + // First count the number of errors to avoid allocating. + var n int + var errFirst error + for _, e := range errs { + switch e := e.(type) { + case nil: + continue + case Error: + n += len(e.errs) + if errFirst == nil && len(e.errs) > 0 { + errFirst = e.errs[0] + } + default: + n++ + if errFirst == nil { + errFirst = e + } + } + } + if n <= 1 { + return errFirst // nil if n == 0 + } + + // More than one error, allocate slice and construct the multi-error. + dst := make([]error, 0, n) + for _, e := range errs { + switch e := e.(type) { + case nil: + continue + case Error: + dst = append(dst, e.errs...) + default: + dst = append(dst, e) + } + } + return Error{errs: dst} +} + +// Is reports whether any error in e matches target. +func (e Error) Is(target error) bool { + for _, err := range e.errs { + if errors.Is(err, target) { + return true + } + } + return false +} + +// As finds the first error in e that matches target, and if any is found, +// sets target to that error value and returns true. Otherwise, it returns false. +func (e Error) As(target any) bool { + for _, err := range e.errs { + if ok := errors.As(err, target); ok { + return true + } + } + return false +} + +// Range performs a pre-order, depth-first iteration of the error tree +// by successively unwrapping all error values. +// For each iteration it calls fn with the current error value and +// stops iteration if it ever reports false. +func Range(err error, fn func(error) bool) bool { + if err == nil { + return true + } + if !fn(err) { + return false + } + switch err := err.(type) { + case interface{ Unwrap() error }: + if err := err.Unwrap(); err != nil { + if !Range(err, fn) { + return false + } + } + case interface{ Unwrap() []error }: + for _, err := range err.Unwrap() { + if !Range(err, fn) { + return false + } + } + } + return true +} diff --git a/util/must/must.go b/util/must/must.go index 056986fca..21965daa9 100644 --- a/util/must/must.go +++ b/util/must/must.go @@ -1,25 +1,25 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package must assists in calling functions that must succeed. -// -// Example usage: -// -// var target = must.Get(url.Parse(...)) -// must.Do(close()) -package must - -// Do panics if err is non-nil. -func Do(err error) { - if err != nil { - panic(err) - } -} - -// Get returns v as is. It panics if err is non-nil. -func Get[T any](v T, err error) T { - if err != nil { - panic(err) - } - return v -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package must assists in calling functions that must succeed. +// +// Example usage: +// +// var target = must.Get(url.Parse(...)) +// must.Do(close()) +package must + +// Do panics if err is non-nil. +func Do(err error) { + if err != nil { + panic(err) + } +} + +// Get returns v as is. It panics if err is non-nil. +func Get[T any](v T, err error) T { + if err != nil { + panic(err) + } + return v +} diff --git a/util/osdiag/mksyscall.go b/util/osdiag/mksyscall.go index f20be7f92..bcbe113b0 100644 --- a/util/osdiag/mksyscall.go +++ b/util/osdiag/mksyscall.go @@ -1,13 +1,13 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package osdiag - -//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go -//go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go - -//sys globalMemoryStatusEx(memStatus *_MEMORYSTATUSEX) (err error) [int32(failretval)==0] = kernel32.GlobalMemoryStatusEx -//sys regEnumValue(key registry.Key, index uint32, valueName *uint16, valueNameLen *uint32, reserved *uint32, valueType *uint32, pData *byte, cbData *uint32) (ret error) [failretval!=0] = advapi32.RegEnumValueW -//sys wscEnumProtocols(iProtocols *int32, protocolBuffer *wsaProtocolInfo, bufLen *uint32, errno *int32) (ret int32) = ws2_32.WSCEnumProtocols -//sys wscGetProviderInfo(providerId *windows.GUID, infoType _WSC_PROVIDER_INFO_TYPE, info unsafe.Pointer, infoSize *uintptr, flags uint32, errno *int32) (ret int32) = ws2_32.WSCGetProviderInfo -//sys wscGetProviderPath(providerId *windows.GUID, providerDllPath *uint16, providerDllPathLen *int32, errno *int32) (ret int32) = ws2_32.WSCGetProviderPath +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package osdiag + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go +//go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go + +//sys globalMemoryStatusEx(memStatus *_MEMORYSTATUSEX) (err error) [int32(failretval)==0] = kernel32.GlobalMemoryStatusEx +//sys regEnumValue(key registry.Key, index uint32, valueName *uint16, valueNameLen *uint32, reserved *uint32, valueType *uint32, pData *byte, cbData *uint32) (ret error) [failretval!=0] = advapi32.RegEnumValueW +//sys wscEnumProtocols(iProtocols *int32, protocolBuffer *wsaProtocolInfo, bufLen *uint32, errno *int32) (ret int32) = ws2_32.WSCEnumProtocols +//sys wscGetProviderInfo(providerId *windows.GUID, infoType _WSC_PROVIDER_INFO_TYPE, info unsafe.Pointer, infoSize *uintptr, flags uint32, errno *int32) (ret int32) = ws2_32.WSCGetProviderInfo +//sys wscGetProviderPath(providerId *windows.GUID, providerDllPath *uint16, providerDllPathLen *int32, errno *int32) (ret int32) = ws2_32.WSCGetProviderPath diff --git a/util/osdiag/osdiag_windows_test.go b/util/osdiag/osdiag_windows_test.go index 776852a34..b29b602cc 100644 --- a/util/osdiag/osdiag_windows_test.go +++ b/util/osdiag/osdiag_windows_test.go @@ -1,128 +1,128 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package osdiag - -import ( - "errors" - "fmt" - "maps" - "strings" - "testing" - - "golang.org/x/sys/windows/registry" -) - -func makeLongBinaryValue() []byte { - buf := make([]byte, maxBinaryValueLen*2) - for i, _ := range buf { - buf[i] = byte(i % 0xFF) - } - return buf -} - -var testData = map[string]any{ - "": "I am the default", - "StringEmpty": "", - "StringShort": "Hello", - "StringLong": strings.Repeat("7", initialValueBufLen+1), - "MultiStringEmpty": []string{}, - "MultiStringSingle": []string{"Foo"}, - "MultiStringSingleEmpty": []string{""}, - "MultiString": []string{"Foo", "Bar", "Baz"}, - "MultiStringWithEmptyBeginning": []string{"", "Foo", "Bar"}, - "MultiStringWithEmptyMiddle": []string{"Foo", "", "Bar"}, - "MultiStringWithEmptyEnd": []string{"Foo", "Bar", ""}, - "DWord": uint32(0x12345678), - "QWord": uint64(0x123456789abcdef0), - "BinaryEmpty": []byte{}, - "BinaryShort": []byte{0x01, 0x02, 0x03, 0x04}, - "BinaryLong": makeLongBinaryValue(), -} - -const ( - keyNameTest = `SOFTWARE\Tailscale Test` - subKeyNameTest = "SubKey" -) - -func setValues(t *testing.T, k registry.Key) { - for vk, v := range testData { - var err error - switch tv := v.(type) { - case string: - err = k.SetStringValue(vk, tv) - case []string: - err = k.SetStringsValue(vk, tv) - case uint32: - err = k.SetDWordValue(vk, tv) - case uint64: - err = k.SetQWordValue(vk, tv) - case []byte: - err = k.SetBinaryValue(vk, tv) - default: - t.Fatalf("Unknown type") - } - - if err != nil { - t.Fatalf("Error setting %q: %v", vk, err) - } - } -} - -func TestRegistrySupportInfo(t *testing.T) { - // Make sure the key doesn't exist yet - k, err := registry.OpenKey(registry.CURRENT_USER, keyNameTest, registry.READ) - switch { - case err == nil: - k.Close() - t.Fatalf("Test key already exists") - case !errors.Is(err, registry.ErrNotExist): - t.Fatal(err) - } - - func() { - k, _, err := registry.CreateKey(registry.CURRENT_USER, keyNameTest, registry.WRITE) - if err != nil { - t.Fatalf("Error creating test key: %v", err) - } - defer k.Close() - - setValues(t, k) - - sk, _, err := registry.CreateKey(k, subKeyNameTest, registry.WRITE) - if err != nil { - t.Fatalf("Error creating test subkey: %v", err) - } - defer sk.Close() - - setValues(t, sk) - }() - - t.Cleanup(func() { - registry.DeleteKey(registry.CURRENT_USER, keyNameTest+"\\"+subKeyNameTest) - registry.DeleteKey(registry.CURRENT_USER, keyNameTest) - }) - - wantValuesData := maps.Clone(testData) - wantValuesData["BinaryLong"] = (wantValuesData["BinaryLong"].([]byte))[:maxBinaryValueLen] - - wantKeyData := make(map[string]any) - maps.Copy(wantKeyData, wantValuesData) - wantSubKeyData := make(map[string]any) - maps.Copy(wantSubKeyData, wantValuesData) - wantKeyData[subKeyNameTest] = wantSubKeyData - - wantData := map[string]any{ - "HKCU\\" + keyNameTest: wantKeyData, - } - - gotData, err := getRegistrySupportInfo(registry.CURRENT_USER, []string{keyNameTest}) - if err != nil { - t.Errorf("getRegistrySupportInfo error: %v", err) - } - - want, got := fmt.Sprintf("%#v", wantData), fmt.Sprintf("%#v", gotData) - if want != got { - t.Errorf("Compare error: want\n%s,\ngot %s", want, got) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package osdiag + +import ( + "errors" + "fmt" + "maps" + "strings" + "testing" + + "golang.org/x/sys/windows/registry" +) + +func makeLongBinaryValue() []byte { + buf := make([]byte, maxBinaryValueLen*2) + for i, _ := range buf { + buf[i] = byte(i % 0xFF) + } + return buf +} + +var testData = map[string]any{ + "": "I am the default", + "StringEmpty": "", + "StringShort": "Hello", + "StringLong": strings.Repeat("7", initialValueBufLen+1), + "MultiStringEmpty": []string{}, + "MultiStringSingle": []string{"Foo"}, + "MultiStringSingleEmpty": []string{""}, + "MultiString": []string{"Foo", "Bar", "Baz"}, + "MultiStringWithEmptyBeginning": []string{"", "Foo", "Bar"}, + "MultiStringWithEmptyMiddle": []string{"Foo", "", "Bar"}, + "MultiStringWithEmptyEnd": []string{"Foo", "Bar", ""}, + "DWord": uint32(0x12345678), + "QWord": uint64(0x123456789abcdef0), + "BinaryEmpty": []byte{}, + "BinaryShort": []byte{0x01, 0x02, 0x03, 0x04}, + "BinaryLong": makeLongBinaryValue(), +} + +const ( + keyNameTest = `SOFTWARE\Tailscale Test` + subKeyNameTest = "SubKey" +) + +func setValues(t *testing.T, k registry.Key) { + for vk, v := range testData { + var err error + switch tv := v.(type) { + case string: + err = k.SetStringValue(vk, tv) + case []string: + err = k.SetStringsValue(vk, tv) + case uint32: + err = k.SetDWordValue(vk, tv) + case uint64: + err = k.SetQWordValue(vk, tv) + case []byte: + err = k.SetBinaryValue(vk, tv) + default: + t.Fatalf("Unknown type") + } + + if err != nil { + t.Fatalf("Error setting %q: %v", vk, err) + } + } +} + +func TestRegistrySupportInfo(t *testing.T) { + // Make sure the key doesn't exist yet + k, err := registry.OpenKey(registry.CURRENT_USER, keyNameTest, registry.READ) + switch { + case err == nil: + k.Close() + t.Fatalf("Test key already exists") + case !errors.Is(err, registry.ErrNotExist): + t.Fatal(err) + } + + func() { + k, _, err := registry.CreateKey(registry.CURRENT_USER, keyNameTest, registry.WRITE) + if err != nil { + t.Fatalf("Error creating test key: %v", err) + } + defer k.Close() + + setValues(t, k) + + sk, _, err := registry.CreateKey(k, subKeyNameTest, registry.WRITE) + if err != nil { + t.Fatalf("Error creating test subkey: %v", err) + } + defer sk.Close() + + setValues(t, sk) + }() + + t.Cleanup(func() { + registry.DeleteKey(registry.CURRENT_USER, keyNameTest+"\\"+subKeyNameTest) + registry.DeleteKey(registry.CURRENT_USER, keyNameTest) + }) + + wantValuesData := maps.Clone(testData) + wantValuesData["BinaryLong"] = (wantValuesData["BinaryLong"].([]byte))[:maxBinaryValueLen] + + wantKeyData := make(map[string]any) + maps.Copy(wantKeyData, wantValuesData) + wantSubKeyData := make(map[string]any) + maps.Copy(wantSubKeyData, wantValuesData) + wantKeyData[subKeyNameTest] = wantSubKeyData + + wantData := map[string]any{ + "HKCU\\" + keyNameTest: wantKeyData, + } + + gotData, err := getRegistrySupportInfo(registry.CURRENT_USER, []string{keyNameTest}) + if err != nil { + t.Errorf("getRegistrySupportInfo error: %v", err) + } + + want, got := fmt.Sprintf("%#v", wantData), fmt.Sprintf("%#v", gotData) + if want != got { + t.Errorf("Compare error: want\n%s,\ngot %s", want, got) + } +} diff --git a/util/osshare/filesharingstatus_noop.go b/util/osshare/filesharingstatus_noop.go index 6be4131a9..7f2b13190 100644 --- a/util/osshare/filesharingstatus_noop.go +++ b/util/osshare/filesharingstatus_noop.go @@ -1,12 +1,12 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows - -package osshare - -import ( - "tailscale.com/types/logger" -) - -func SetFileSharingEnabled(enabled bool, logf logger.Logf) {} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows + +package osshare + +import ( + "tailscale.com/types/logger" +) + +func SetFileSharingEnabled(enabled bool, logf logger.Logf) {} diff --git a/util/pidowner/pidowner.go b/util/pidowner/pidowner.go index 62ea85d78..56bb640b7 100644 --- a/util/pidowner/pidowner.go +++ b/util/pidowner/pidowner.go @@ -1,24 +1,24 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package pidowner handles lookups from process ID to its owning user. -package pidowner - -import ( - "errors" - "runtime" -) - -var ErrNotImplemented = errors.New("not implemented for GOOS=" + runtime.GOOS) - -var ErrProcessNotFound = errors.New("process not found") - -// OwnerOfPID returns the user ID that owns the given process ID. -// -// The returned user ID is suitable to passing to os/user.LookupId. -// -// The returned error will be ErrNotImplemented for operating systems where -// this isn't supported. -func OwnerOfPID(pid int) (userID string, err error) { - return ownerOfPID(pid) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package pidowner handles lookups from process ID to its owning user. +package pidowner + +import ( + "errors" + "runtime" +) + +var ErrNotImplemented = errors.New("not implemented for GOOS=" + runtime.GOOS) + +var ErrProcessNotFound = errors.New("process not found") + +// OwnerOfPID returns the user ID that owns the given process ID. +// +// The returned user ID is suitable to passing to os/user.LookupId. +// +// The returned error will be ErrNotImplemented for operating systems where +// this isn't supported. +func OwnerOfPID(pid int) (userID string, err error) { + return ownerOfPID(pid) +} diff --git a/util/pidowner/pidowner_noimpl.go b/util/pidowner/pidowner_noimpl.go index a631e3f24..50add492f 100644 --- a/util/pidowner/pidowner_noimpl.go +++ b/util/pidowner/pidowner_noimpl.go @@ -1,8 +1,8 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows && !linux - -package pidowner - -func ownerOfPID(pid int) (userID string, err error) { return "", ErrNotImplemented } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows && !linux + +package pidowner + +func ownerOfPID(pid int) (userID string, err error) { return "", ErrNotImplemented } diff --git a/util/pidowner/pidowner_windows.go b/util/pidowner/pidowner_windows.go index c7b2512a4..dbf13ac81 100644 --- a/util/pidowner/pidowner_windows.go +++ b/util/pidowner/pidowner_windows.go @@ -1,35 +1,35 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package pidowner - -import ( - "fmt" - "syscall" - - "golang.org/x/sys/windows" -) - -func ownerOfPID(pid int) (userID string, err error) { - procHnd, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION, false, uint32(pid)) - if err == syscall.Errno(0x57) { // invalid parameter, for PIDs that don't exist - return "", ErrProcessNotFound - } - if err != nil { - return "", fmt.Errorf("OpenProcess: %T %#v", err, err) - } - defer windows.CloseHandle(procHnd) - - var tok windows.Token - if err := windows.OpenProcessToken(procHnd, windows.TOKEN_QUERY, &tok); err != nil { - return "", fmt.Errorf("OpenProcessToken: %w", err) - } - - tokUser, err := tok.GetTokenUser() - if err != nil { - return "", fmt.Errorf("GetTokenUser: %w", err) - } - - sid := tokUser.User.Sid - return sid.String(), nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package pidowner + +import ( + "fmt" + "syscall" + + "golang.org/x/sys/windows" +) + +func ownerOfPID(pid int) (userID string, err error) { + procHnd, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION, false, uint32(pid)) + if err == syscall.Errno(0x57) { // invalid parameter, for PIDs that don't exist + return "", ErrProcessNotFound + } + if err != nil { + return "", fmt.Errorf("OpenProcess: %T %#v", err, err) + } + defer windows.CloseHandle(procHnd) + + var tok windows.Token + if err := windows.OpenProcessToken(procHnd, windows.TOKEN_QUERY, &tok); err != nil { + return "", fmt.Errorf("OpenProcessToken: %w", err) + } + + tokUser, err := tok.GetTokenUser() + if err != nil { + return "", fmt.Errorf("GetTokenUser: %w", err) + } + + sid := tokUser.User.Sid + return sid.String(), nil +} diff --git a/util/precompress/precompress.go b/util/precompress/precompress.go index e9bebb333..6d1a26efd 100644 --- a/util/precompress/precompress.go +++ b/util/precompress/precompress.go @@ -1,129 +1,129 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package precompress provides build- and serving-time support for -// precompressed static resources, to avoid the cost of repeatedly compressing -// unchanging resources. -package precompress - -import ( - "bytes" - "compress/gzip" - "io" - "io/fs" - "net/http" - "os" - "path" - "path/filepath" - - "github.com/andybalholm/brotli" - "golang.org/x/sync/errgroup" - "tailscale.com/tsweb" -) - -// PrecompressDir compresses static assets in dirPath using Gzip and Brotli, so -// that they can be later served with OpenPrecompressedFile. -func PrecompressDir(dirPath string, options Options) error { - var eg errgroup.Group - err := fs.WalkDir(os.DirFS(dirPath), ".", func(p string, d fs.DirEntry, err error) error { - if err != nil { - return err - } - if d.IsDir() { - return nil - } - if !compressibleExtensions[filepath.Ext(p)] { - return nil - } - p = path.Join(dirPath, p) - if options.ProgressFn != nil { - options.ProgressFn(p) - } - - eg.Go(func() error { - return Precompress(p, options) - }) - return nil - }) - if err != nil { - return err - } - return eg.Wait() -} - -type Options struct { - // FastCompression controls whether compression should be optimized for - // speed rather than size. - FastCompression bool - // ProgressFn, if non-nil, is invoked when a file in the directory is about - // to be compressed. - ProgressFn func(path string) -} - -// OpenPrecompressedFile opens a file from fs, preferring compressed versions -// generated by PrecompressDir if possible. -func OpenPrecompressedFile(w http.ResponseWriter, r *http.Request, path string, fs fs.FS) (fs.File, error) { - if tsweb.AcceptsEncoding(r, "br") { - if f, err := fs.Open(path + ".br"); err == nil { - w.Header().Set("Content-Encoding", "br") - return f, nil - } - } - if tsweb.AcceptsEncoding(r, "gzip") { - if f, err := fs.Open(path + ".gz"); err == nil { - w.Header().Set("Content-Encoding", "gzip") - return f, nil - } - } - - return fs.Open(path) -} - -var compressibleExtensions = map[string]bool{ - ".js": true, - ".css": true, -} - -func Precompress(path string, options Options) error { - contents, err := os.ReadFile(path) - if err != nil { - return err - } - fi, err := os.Lstat(path) - if err != nil { - return err - } - - gzipLevel := gzip.BestCompression - if options.FastCompression { - gzipLevel = gzip.BestSpeed - } - err = writeCompressed(contents, func(w io.Writer) (io.WriteCloser, error) { - return gzip.NewWriterLevel(w, gzipLevel) - }, path+".gz", fi.Mode()) - if err != nil { - return err - } - brotliLevel := brotli.BestCompression - if options.FastCompression { - brotliLevel = brotli.BestSpeed - } - return writeCompressed(contents, func(w io.Writer) (io.WriteCloser, error) { - return brotli.NewWriterLevel(w, brotliLevel), nil - }, path+".br", fi.Mode()) -} - -func writeCompressed(contents []byte, compressedWriterCreator func(io.Writer) (io.WriteCloser, error), outputPath string, outputMode fs.FileMode) error { - var buf bytes.Buffer - compressedWriter, err := compressedWriterCreator(&buf) - if err != nil { - return err - } - if _, err := compressedWriter.Write(contents); err != nil { - return err - } - if err := compressedWriter.Close(); err != nil { - return err - } - return os.WriteFile(outputPath, buf.Bytes(), outputMode) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package precompress provides build- and serving-time support for +// precompressed static resources, to avoid the cost of repeatedly compressing +// unchanging resources. +package precompress + +import ( + "bytes" + "compress/gzip" + "io" + "io/fs" + "net/http" + "os" + "path" + "path/filepath" + + "github.com/andybalholm/brotli" + "golang.org/x/sync/errgroup" + "tailscale.com/tsweb" +) + +// PrecompressDir compresses static assets in dirPath using Gzip and Brotli, so +// that they can be later served with OpenPrecompressedFile. +func PrecompressDir(dirPath string, options Options) error { + var eg errgroup.Group + err := fs.WalkDir(os.DirFS(dirPath), ".", func(p string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + return nil + } + if !compressibleExtensions[filepath.Ext(p)] { + return nil + } + p = path.Join(dirPath, p) + if options.ProgressFn != nil { + options.ProgressFn(p) + } + + eg.Go(func() error { + return Precompress(p, options) + }) + return nil + }) + if err != nil { + return err + } + return eg.Wait() +} + +type Options struct { + // FastCompression controls whether compression should be optimized for + // speed rather than size. + FastCompression bool + // ProgressFn, if non-nil, is invoked when a file in the directory is about + // to be compressed. + ProgressFn func(path string) +} + +// OpenPrecompressedFile opens a file from fs, preferring compressed versions +// generated by PrecompressDir if possible. +func OpenPrecompressedFile(w http.ResponseWriter, r *http.Request, path string, fs fs.FS) (fs.File, error) { + if tsweb.AcceptsEncoding(r, "br") { + if f, err := fs.Open(path + ".br"); err == nil { + w.Header().Set("Content-Encoding", "br") + return f, nil + } + } + if tsweb.AcceptsEncoding(r, "gzip") { + if f, err := fs.Open(path + ".gz"); err == nil { + w.Header().Set("Content-Encoding", "gzip") + return f, nil + } + } + + return fs.Open(path) +} + +var compressibleExtensions = map[string]bool{ + ".js": true, + ".css": true, +} + +func Precompress(path string, options Options) error { + contents, err := os.ReadFile(path) + if err != nil { + return err + } + fi, err := os.Lstat(path) + if err != nil { + return err + } + + gzipLevel := gzip.BestCompression + if options.FastCompression { + gzipLevel = gzip.BestSpeed + } + err = writeCompressed(contents, func(w io.Writer) (io.WriteCloser, error) { + return gzip.NewWriterLevel(w, gzipLevel) + }, path+".gz", fi.Mode()) + if err != nil { + return err + } + brotliLevel := brotli.BestCompression + if options.FastCompression { + brotliLevel = brotli.BestSpeed + } + return writeCompressed(contents, func(w io.Writer) (io.WriteCloser, error) { + return brotli.NewWriterLevel(w, brotliLevel), nil + }, path+".br", fi.Mode()) +} + +func writeCompressed(contents []byte, compressedWriterCreator func(io.Writer) (io.WriteCloser, error), outputPath string, outputMode fs.FileMode) error { + var buf bytes.Buffer + compressedWriter, err := compressedWriterCreator(&buf) + if err != nil { + return err + } + if _, err := compressedWriter.Write(contents); err != nil { + return err + } + if err := compressedWriter.Close(); err != nil { + return err + } + return os.WriteFile(outputPath, buf.Bytes(), outputMode) +} diff --git a/util/quarantine/quarantine.go b/util/quarantine/quarantine.go index 488465ba0..7ad65a81d 100644 --- a/util/quarantine/quarantine.go +++ b/util/quarantine/quarantine.go @@ -1,14 +1,14 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package quarantine sets platform specific "quarantine" attributes on files -// that are received from other hosts. -package quarantine - -import "os" - -// SetOnFile sets the platform-specific quarantine attribute (if any) on the -// provided file. -func SetOnFile(f *os.File) error { - return setQuarantineAttr(f) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package quarantine sets platform specific "quarantine" attributes on files +// that are received from other hosts. +package quarantine + +import "os" + +// SetOnFile sets the platform-specific quarantine attribute (if any) on the +// provided file. +func SetOnFile(f *os.File) error { + return setQuarantineAttr(f) +} diff --git a/util/quarantine/quarantine_darwin.go b/util/quarantine/quarantine_darwin.go index b7757f334..35405d9cc 100644 --- a/util/quarantine/quarantine_darwin.go +++ b/util/quarantine/quarantine_darwin.go @@ -1,56 +1,56 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package quarantine - -import ( - "fmt" - "os" - "strings" - "time" - - "github.com/google/uuid" - "golang.org/x/sys/unix" -) - -func setQuarantineAttr(f *os.File) error { - sc, err := f.SyscallConn() - if err != nil { - return err - } - - now := time.Now() - - // We uppercase the UUID to match what other applications on macOS do - id := strings.ToUpper(uuid.New().String()) - - // kLSQuarantineTypeOtherDownload; this matches what AirDrop sets when - // receiving a file. - quarantineType := "0001" - - // This format is under-documented, but the following links contain a - // reasonably comprehensive overview: - // https://eclecticlight.co/2020/10/29/quarantine-and-the-quarantine-flag/ - // https://nixhacker.com/security-protection-in-macos-1/ - // https://ilostmynotes.blogspot.com/2012/06/gatekeeper-xprotect-and-quarantine.html - attrData := fmt.Sprintf("%s;%x;%s;%s", - quarantineType, // quarantine value - now.Unix(), // time in hex - "Tailscale", // application - id, // UUID - ) - - var innerErr error - err = sc.Control(func(fd uintptr) { - innerErr = unix.Fsetxattr( - int(fd), - "com.apple.quarantine", // attr - []byte(attrData), - 0, - ) - }) - if err != nil { - return err - } - return innerErr -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package quarantine + +import ( + "fmt" + "os" + "strings" + "time" + + "github.com/google/uuid" + "golang.org/x/sys/unix" +) + +func setQuarantineAttr(f *os.File) error { + sc, err := f.SyscallConn() + if err != nil { + return err + } + + now := time.Now() + + // We uppercase the UUID to match what other applications on macOS do + id := strings.ToUpper(uuid.New().String()) + + // kLSQuarantineTypeOtherDownload; this matches what AirDrop sets when + // receiving a file. + quarantineType := "0001" + + // This format is under-documented, but the following links contain a + // reasonably comprehensive overview: + // https://eclecticlight.co/2020/10/29/quarantine-and-the-quarantine-flag/ + // https://nixhacker.com/security-protection-in-macos-1/ + // https://ilostmynotes.blogspot.com/2012/06/gatekeeper-xprotect-and-quarantine.html + attrData := fmt.Sprintf("%s;%x;%s;%s", + quarantineType, // quarantine value + now.Unix(), // time in hex + "Tailscale", // application + id, // UUID + ) + + var innerErr error + err = sc.Control(func(fd uintptr) { + innerErr = unix.Fsetxattr( + int(fd), + "com.apple.quarantine", // attr + []byte(attrData), + 0, + ) + }) + if err != nil { + return err + } + return innerErr +} diff --git a/util/quarantine/quarantine_default.go b/util/quarantine/quarantine_default.go index 65a14ed26..65954a4d2 100644 --- a/util/quarantine/quarantine_default.go +++ b/util/quarantine/quarantine_default.go @@ -1,14 +1,14 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !darwin && !windows - -package quarantine - -import ( - "os" -) - -func setQuarantineAttr(f *os.File) error { - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !darwin && !windows + +package quarantine + +import ( + "os" +) + +func setQuarantineAttr(f *os.File) error { + return nil +} diff --git a/util/quarantine/quarantine_windows.go b/util/quarantine/quarantine_windows.go index 3052c2c6d..6fdf4e699 100644 --- a/util/quarantine/quarantine_windows.go +++ b/util/quarantine/quarantine_windows.go @@ -1,29 +1,29 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package quarantine - -import ( - "os" - "strings" -) - -func setQuarantineAttr(f *os.File) error { - // Documentation on this can be found here: - // https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-fscc/6e3f7352-d11c-4d76-8c39-2516a9df36e8 - // - // Additional information can be found at: - // https://www.digital-detective.net/forensic-analysis-of-zone-identifier-stream/ - // https://bugzilla.mozilla.org/show_bug.cgi?id=1433179 - content := strings.Join([]string{ - "[ZoneTransfer]", - - // "URLZONE_INTERNET" - // https://docs.microsoft.com/en-us/previous-versions/windows/internet-explorer/ie-developer/platform-apis/ms537175(v=vs.85) - "ZoneId=3", - - // TODO(andrew): should/could we add ReferrerUrl or HostUrl? - }, "\r\n") - - return os.WriteFile(f.Name()+":Zone.Identifier", []byte(content), 0) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package quarantine + +import ( + "os" + "strings" +) + +func setQuarantineAttr(f *os.File) error { + // Documentation on this can be found here: + // https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-fscc/6e3f7352-d11c-4d76-8c39-2516a9df36e8 + // + // Additional information can be found at: + // https://www.digital-detective.net/forensic-analysis-of-zone-identifier-stream/ + // https://bugzilla.mozilla.org/show_bug.cgi?id=1433179 + content := strings.Join([]string{ + "[ZoneTransfer]", + + // "URLZONE_INTERNET" + // https://docs.microsoft.com/en-us/previous-versions/windows/internet-explorer/ie-developer/platform-apis/ms537175(v=vs.85) + "ZoneId=3", + + // TODO(andrew): should/could we add ReferrerUrl or HostUrl? + }, "\r\n") + + return os.WriteFile(f.Name()+":Zone.Identifier", []byte(content), 0) +} diff --git a/util/race/race_test.go b/util/race/race_test.go index 17ea76459..d38382712 100644 --- a/util/race/race_test.go +++ b/util/race/race_test.go @@ -1,99 +1,99 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package race - -import ( - "context" - "errors" - "testing" - "time" - - "tailscale.com/tstest" -) - -func TestRaceSuccess1(t *testing.T) { - tstest.ResourceCheck(t) - - const want = "success" - rh := New[string]( - 10*time.Second, - func(context.Context) (string, error) { - return want, nil - }, func(context.Context) (string, error) { - t.Fatal("should not be called") - return "", nil - }) - res, err := rh.Start(context.Background()) - if err != nil { - t.Fatal(err) - } - if res != want { - t.Errorf("got res=%q, want %q", res, want) - } -} - -func TestRaceRetry(t *testing.T) { - tstest.ResourceCheck(t) - - const want = "fallback" - rh := New[string]( - 10*time.Second, - func(context.Context) (string, error) { - return "", errors.New("some error") - }, func(context.Context) (string, error) { - return want, nil - }) - res, err := rh.Start(context.Background()) - if err != nil { - t.Fatal(err) - } - if res != want { - t.Errorf("got res=%q, want %q", res, want) - } -} - -func TestRaceTimeout(t *testing.T) { - tstest.ResourceCheck(t) - - const want = "fallback" - rh := New[string]( - 100*time.Millisecond, - func(ctx context.Context) (string, error) { - // Block forever - <-ctx.Done() - return "", ctx.Err() - }, func(context.Context) (string, error) { - return want, nil - }) - res, err := rh.Start(context.Background()) - if err != nil { - t.Fatal(err) - } - if res != want { - t.Errorf("got res=%q, want %q", res, want) - } -} - -func TestRaceError(t *testing.T) { - tstest.ResourceCheck(t) - - err1 := errors.New("error 1") - err2 := errors.New("error 2") - - rh := New[string]( - 100*time.Millisecond, - func(ctx context.Context) (string, error) { - return "", err1 - }, func(context.Context) (string, error) { - return "", err2 - }) - - _, err := rh.Start(context.Background()) - if !errors.Is(err, err1) { - t.Errorf("wanted err to contain err1; got %v", err) - } - if !errors.Is(err, err2) { - t.Errorf("wanted err to contain err2; got %v", err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package race + +import ( + "context" + "errors" + "testing" + "time" + + "tailscale.com/tstest" +) + +func TestRaceSuccess1(t *testing.T) { + tstest.ResourceCheck(t) + + const want = "success" + rh := New[string]( + 10*time.Second, + func(context.Context) (string, error) { + return want, nil + }, func(context.Context) (string, error) { + t.Fatal("should not be called") + return "", nil + }) + res, err := rh.Start(context.Background()) + if err != nil { + t.Fatal(err) + } + if res != want { + t.Errorf("got res=%q, want %q", res, want) + } +} + +func TestRaceRetry(t *testing.T) { + tstest.ResourceCheck(t) + + const want = "fallback" + rh := New[string]( + 10*time.Second, + func(context.Context) (string, error) { + return "", errors.New("some error") + }, func(context.Context) (string, error) { + return want, nil + }) + res, err := rh.Start(context.Background()) + if err != nil { + t.Fatal(err) + } + if res != want { + t.Errorf("got res=%q, want %q", res, want) + } +} + +func TestRaceTimeout(t *testing.T) { + tstest.ResourceCheck(t) + + const want = "fallback" + rh := New[string]( + 100*time.Millisecond, + func(ctx context.Context) (string, error) { + // Block forever + <-ctx.Done() + return "", ctx.Err() + }, func(context.Context) (string, error) { + return want, nil + }) + res, err := rh.Start(context.Background()) + if err != nil { + t.Fatal(err) + } + if res != want { + t.Errorf("got res=%q, want %q", res, want) + } +} + +func TestRaceError(t *testing.T) { + tstest.ResourceCheck(t) + + err1 := errors.New("error 1") + err2 := errors.New("error 2") + + rh := New[string]( + 100*time.Millisecond, + func(ctx context.Context) (string, error) { + return "", err1 + }, func(context.Context) (string, error) { + return "", err2 + }) + + _, err := rh.Start(context.Background()) + if !errors.Is(err, err1) { + t.Errorf("wanted err to contain err1; got %v", err) + } + if !errors.Is(err, err2) { + t.Errorf("wanted err to contain err2; got %v", err) + } +} diff --git a/util/racebuild/off.go b/util/racebuild/off.go index a0dba0f32..8f4fe998f 100644 --- a/util/racebuild/off.go +++ b/util/racebuild/off.go @@ -1,8 +1,8 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !race - -package racebuild - -const On = false +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !race + +package racebuild + +const On = false diff --git a/util/racebuild/on.go b/util/racebuild/on.go index c60bca2e6..69ae2bcae 100644 --- a/util/racebuild/on.go +++ b/util/racebuild/on.go @@ -1,8 +1,8 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build race - -package racebuild - -const On = true +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build race + +package racebuild + +const On = true diff --git a/util/racebuild/racebuild.go b/util/racebuild/racebuild.go index c1a43eb96..d061276cb 100644 --- a/util/racebuild/racebuild.go +++ b/util/racebuild/racebuild.go @@ -1,6 +1,6 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package racebuild exports a constant about whether the current binary -// was built with the race detector. -package racebuild +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package racebuild exports a constant about whether the current binary +// was built with the race detector. +package racebuild diff --git a/util/rands/rands.go b/util/rands/rands.go index dcd75c5f3..d83e1e558 100644 --- a/util/rands/rands.go +++ b/util/rands/rands.go @@ -1,25 +1,25 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package rands contains utility functions for randomness. -package rands - -import ( - crand "crypto/rand" - "encoding/hex" -) - -// HexString returns a string of n cryptographically random lowercase -// hex characters. -// -// That is, HexString(3) returns something like "0fc", containing 12 -// bits of randomness. -func HexString(n int) string { - nb := n / 2 - if n%2 == 1 { - nb++ - } - b := make([]byte, nb) - crand.Read(b) - return hex.EncodeToString(b)[:n] -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package rands contains utility functions for randomness. +package rands + +import ( + crand "crypto/rand" + "encoding/hex" +) + +// HexString returns a string of n cryptographically random lowercase +// hex characters. +// +// That is, HexString(3) returns something like "0fc", containing 12 +// bits of randomness. +func HexString(n int) string { + nb := n / 2 + if n%2 == 1 { + nb++ + } + b := make([]byte, nb) + crand.Read(b) + return hex.EncodeToString(b)[:n] +} diff --git a/util/rands/rands_test.go b/util/rands/rands_test.go index ec339f94b..5813f2bb4 100644 --- a/util/rands/rands_test.go +++ b/util/rands/rands_test.go @@ -1,15 +1,15 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package rands - -import "testing" - -func TestHexString(t *testing.T) { - for i := 0; i <= 8; i++ { - s := HexString(i) - if len(s) != i { - t.Errorf("HexString(%v) = %q; want len %v, not %v", i, s, i, len(s)) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package rands + +import "testing" + +func TestHexString(t *testing.T) { + for i := 0; i <= 8; i++ { + s := HexString(i) + if len(s) != i { + t.Errorf("HexString(%v) = %q; want len %v, not %v", i, s, i, len(s)) + } + } +} diff --git a/util/set/handle.go b/util/set/handle.go index 61b4eb93d..471ceeba2 100644 --- a/util/set/handle.go +++ b/util/set/handle.go @@ -1,28 +1,28 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package set - -// HandleSet is a set of T. -// -// It is not safe for concurrent use. -type HandleSet[T any] map[Handle]T - -// Handle is an opaque comparable value that's used as the map key in a -// HandleSet. The only way to get one is to call HandleSet.Add. -type Handle struct { - v *byte -} - -// Add adds the element (map value) e to the set. -// -// It returns the handle (map key) with which e can be removed, using a map -// delete. -func (s *HandleSet[T]) Add(e T) Handle { - h := Handle{new(byte)} - if *s == nil { - *s = make(HandleSet[T]) - } - (*s)[h] = e - return h -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package set + +// HandleSet is a set of T. +// +// It is not safe for concurrent use. +type HandleSet[T any] map[Handle]T + +// Handle is an opaque comparable value that's used as the map key in a +// HandleSet. The only way to get one is to call HandleSet.Add. +type Handle struct { + v *byte +} + +// Add adds the element (map value) e to the set. +// +// It returns the handle (map key) with which e can be removed, using a map +// delete. +func (s *HandleSet[T]) Add(e T) Handle { + h := Handle{new(byte)} + if *s == nil { + *s = make(HandleSet[T]) + } + (*s)[h] = e + return h +} diff --git a/util/set/slice_test.go b/util/set/slice_test.go index ca57e52e8..9134c2962 100644 --- a/util/set/slice_test.go +++ b/util/set/slice_test.go @@ -1,56 +1,56 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package set - -import ( - "testing" - - qt "github.com/frankban/quicktest" -) - -func TestSliceSet(t *testing.T) { - c := qt.New(t) - - var ss Slice[int] - c.Check(len(ss.slice), qt.Equals, 0) - ss.Add(1) - c.Check(len(ss.slice), qt.Equals, 1) - c.Check(len(ss.set), qt.Equals, 0) - c.Check(ss.Contains(1), qt.Equals, true) - c.Check(ss.Contains(2), qt.Equals, false) - - ss.Add(1) - c.Check(len(ss.slice), qt.Equals, 1) - c.Check(len(ss.set), qt.Equals, 0) - - ss.Add(2) - ss.Add(3) - ss.Add(4) - ss.Add(5) - ss.Add(6) - ss.Add(7) - ss.Add(8) - c.Check(len(ss.slice), qt.Equals, 8) - c.Check(len(ss.set), qt.Equals, 0) - - ss.Add(9) - c.Check(len(ss.slice), qt.Equals, 9) - c.Check(len(ss.set), qt.Equals, 9) - - ss.Remove(4) - c.Check(len(ss.slice), qt.Equals, 8) - c.Check(len(ss.set), qt.Equals, 8) - c.Assert(ss.Contains(4), qt.IsFalse) - - // Ensure that the order of insertion is maintained - c.Assert(ss.Slice().AsSlice(), qt.DeepEquals, []int{1, 2, 3, 5, 6, 7, 8, 9}) - ss.Add(4) - c.Check(len(ss.slice), qt.Equals, 9) - c.Check(len(ss.set), qt.Equals, 9) - c.Assert(ss.Contains(4), qt.IsTrue) - c.Assert(ss.Slice().AsSlice(), qt.DeepEquals, []int{1, 2, 3, 5, 6, 7, 8, 9, 4}) - - ss.Add(1, 234, 556) - c.Assert(ss.Slice().AsSlice(), qt.DeepEquals, []int{1, 2, 3, 5, 6, 7, 8, 9, 4, 234, 556}) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package set + +import ( + "testing" + + qt "github.com/frankban/quicktest" +) + +func TestSliceSet(t *testing.T) { + c := qt.New(t) + + var ss Slice[int] + c.Check(len(ss.slice), qt.Equals, 0) + ss.Add(1) + c.Check(len(ss.slice), qt.Equals, 1) + c.Check(len(ss.set), qt.Equals, 0) + c.Check(ss.Contains(1), qt.Equals, true) + c.Check(ss.Contains(2), qt.Equals, false) + + ss.Add(1) + c.Check(len(ss.slice), qt.Equals, 1) + c.Check(len(ss.set), qt.Equals, 0) + + ss.Add(2) + ss.Add(3) + ss.Add(4) + ss.Add(5) + ss.Add(6) + ss.Add(7) + ss.Add(8) + c.Check(len(ss.slice), qt.Equals, 8) + c.Check(len(ss.set), qt.Equals, 0) + + ss.Add(9) + c.Check(len(ss.slice), qt.Equals, 9) + c.Check(len(ss.set), qt.Equals, 9) + + ss.Remove(4) + c.Check(len(ss.slice), qt.Equals, 8) + c.Check(len(ss.set), qt.Equals, 8) + c.Assert(ss.Contains(4), qt.IsFalse) + + // Ensure that the order of insertion is maintained + c.Assert(ss.Slice().AsSlice(), qt.DeepEquals, []int{1, 2, 3, 5, 6, 7, 8, 9}) + ss.Add(4) + c.Check(len(ss.slice), qt.Equals, 9) + c.Check(len(ss.set), qt.Equals, 9) + c.Assert(ss.Contains(4), qt.IsTrue) + c.Assert(ss.Slice().AsSlice(), qt.DeepEquals, []int{1, 2, 3, 5, 6, 7, 8, 9, 4}) + + ss.Add(1, 234, 556) + c.Assert(ss.Slice().AsSlice(), qt.DeepEquals, []int{1, 2, 3, 5, 6, 7, 8, 9, 4, 234, 556}) +} diff --git a/util/sysresources/memory.go b/util/sysresources/memory.go index 8bf784e13..7363155cd 100644 --- a/util/sysresources/memory.go +++ b/util/sysresources/memory.go @@ -1,10 +1,10 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package sysresources - -// TotalMemory returns the total accessible system memory, in bytes. If the -// value cannot be determined, then 0 will be returned. -func TotalMemory() uint64 { - return totalMemoryImpl() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package sysresources + +// TotalMemory returns the total accessible system memory, in bytes. If the +// value cannot be determined, then 0 will be returned. +func TotalMemory() uint64 { + return totalMemoryImpl() +} diff --git a/util/sysresources/memory_bsd.go b/util/sysresources/memory_bsd.go index 39d3a18a9..26850dce6 100644 --- a/util/sysresources/memory_bsd.go +++ b/util/sysresources/memory_bsd.go @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build freebsd || openbsd || dragonfly || netbsd - -package sysresources - -import "golang.org/x/sys/unix" - -func totalMemoryImpl() uint64 { - val, err := unix.SysctlUint64("hw.physmem") - if err != nil { - return 0 - } - return val -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build freebsd || openbsd || dragonfly || netbsd + +package sysresources + +import "golang.org/x/sys/unix" + +func totalMemoryImpl() uint64 { + val, err := unix.SysctlUint64("hw.physmem") + if err != nil { + return 0 + } + return val +} diff --git a/util/sysresources/memory_darwin.go b/util/sysresources/memory_darwin.go index 2f74b6cec..e07bac0cd 100644 --- a/util/sysresources/memory_darwin.go +++ b/util/sysresources/memory_darwin.go @@ -1,16 +1,16 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build darwin - -package sysresources - -import "golang.org/x/sys/unix" - -func totalMemoryImpl() uint64 { - val, err := unix.SysctlUint64("hw.memsize") - if err != nil { - return 0 - } - return val -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build darwin + +package sysresources + +import "golang.org/x/sys/unix" + +func totalMemoryImpl() uint64 { + val, err := unix.SysctlUint64("hw.memsize") + if err != nil { + return 0 + } + return val +} diff --git a/util/sysresources/memory_linux.go b/util/sysresources/memory_linux.go index f3c51469f..0239b0e80 100644 --- a/util/sysresources/memory_linux.go +++ b/util/sysresources/memory_linux.go @@ -1,19 +1,19 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux - -package sysresources - -import "golang.org/x/sys/unix" - -func totalMemoryImpl() uint64 { - var info unix.Sysinfo_t - - if err := unix.Sysinfo(&info); err != nil { - return 0 - } - - // uint64 casts are required since these might be uint32s - return uint64(info.Totalram) * uint64(info.Unit) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package sysresources + +import "golang.org/x/sys/unix" + +func totalMemoryImpl() uint64 { + var info unix.Sysinfo_t + + if err := unix.Sysinfo(&info); err != nil { + return 0 + } + + // uint64 casts are required since these might be uint32s + return uint64(info.Totalram) * uint64(info.Unit) +} diff --git a/util/sysresources/memory_unsupported.go b/util/sysresources/memory_unsupported.go index f80ef4e6e..0fde256e0 100644 --- a/util/sysresources/memory_unsupported.go +++ b/util/sysresources/memory_unsupported.go @@ -1,8 +1,8 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !(linux || darwin || freebsd || openbsd || dragonfly || netbsd) - -package sysresources - -func totalMemoryImpl() uint64 { return 0 } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !(linux || darwin || freebsd || openbsd || dragonfly || netbsd) + +package sysresources + +func totalMemoryImpl() uint64 { return 0 } diff --git a/util/sysresources/sysresources.go b/util/sysresources/sysresources.go index 1cce164a7..32d972ab1 100644 --- a/util/sysresources/sysresources.go +++ b/util/sysresources/sysresources.go @@ -1,6 +1,6 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package sysresources provides OS-independent methods of determining the -// resources available to the current system. -package sysresources +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package sysresources provides OS-independent methods of determining the +// resources available to the current system. +package sysresources diff --git a/util/sysresources/sysresources_test.go b/util/sysresources/sysresources_test.go index af9662042..331ad913b 100644 --- a/util/sysresources/sysresources_test.go +++ b/util/sysresources/sysresources_test.go @@ -1,25 +1,25 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package sysresources - -import ( - "runtime" - "testing" -) - -func TestTotalMemory(t *testing.T) { - switch runtime.GOOS { - case "linux": - case "freebsd", "openbsd", "dragonfly", "netbsd": - case "darwin": - default: - t.Skipf("not supported on runtime.GOOS=%q yet", runtime.GOOS) - } - - mem := TotalMemory() - if mem == 0 { - t.Fatal("wanted TotalMemory > 0") - } - t.Logf("total memory: %v bytes", mem) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package sysresources + +import ( + "runtime" + "testing" +) + +func TestTotalMemory(t *testing.T) { + switch runtime.GOOS { + case "linux": + case "freebsd", "openbsd", "dragonfly", "netbsd": + case "darwin": + default: + t.Skipf("not supported on runtime.GOOS=%q yet", runtime.GOOS) + } + + mem := TotalMemory() + if mem == 0 { + t.Fatal("wanted TotalMemory > 0") + } + t.Logf("total memory: %v bytes", mem) +} diff --git a/util/systemd/doc.go b/util/systemd/doc.go index 296f74e9d..0c28e1823 100644 --- a/util/systemd/doc.go +++ b/util/systemd/doc.go @@ -1,13 +1,13 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -/* -Package systemd contains a minimal wrapper around systemd-notify to enable -applications to signal readiness and status to systemd. - -This package will only have effect on Linux systems running Tailscale in a -systemd unit with the Type=notify flag set. On other operating systems (or -when running in a Linux distro without being run from inside systemd) this -package will become a no-op. -*/ -package systemd +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +/* +Package systemd contains a minimal wrapper around systemd-notify to enable +applications to signal readiness and status to systemd. + +This package will only have effect on Linux systems running Tailscale in a +systemd unit with the Type=notify flag set. On other operating systems (or +when running in a Linux distro without being run from inside systemd) this +package will become a no-op. +*/ +package systemd diff --git a/util/systemd/systemd_linux.go b/util/systemd/systemd_linux.go index 34d6daff3..909cfcb20 100644 --- a/util/systemd/systemd_linux.go +++ b/util/systemd/systemd_linux.go @@ -1,77 +1,77 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux - -package systemd - -import ( - "errors" - "log" - "os" - "sync" - - "github.com/mdlayher/sdnotify" -) - -var getNotifyOnce struct { - sync.Once - v *sdnotify.Notifier -} - -type logOnce struct { - sync.Once -} - -func (l *logOnce) logf(format string, args ...any) { - l.Once.Do(func() { - log.Printf(format, args...) - }) -} - -var ( - readyOnce = &logOnce{} - statusOnce = &logOnce{} -) - -func notifier() *sdnotify.Notifier { - getNotifyOnce.Do(func() { - var err error - getNotifyOnce.v, err = sdnotify.New() - // Not exist means probably not running under systemd, so don't log. - if err != nil && !errors.Is(err, os.ErrNotExist) { - log.Printf("systemd: systemd-notifier error: %v", err) - } - }) - return getNotifyOnce.v -} - -// Ready signals readiness to systemd. This will unblock service dependents from starting. -func Ready() { - err := notifier().Notify(sdnotify.Ready) - if err != nil { - readyOnce.logf("systemd: error notifying: %v", err) - } -} - -// Status sends a single line status update to systemd so that information shows up -// in systemctl output. For example: -// -// $ systemctl status tailscale -// ● tailscale.service - Tailscale client daemon -// Loaded: loaded (/nix/store/qc312qcy907wz80fqrgbbm8a9djafmlg-unit-tailscale.service/tailscale.service; enabled; vendor preset: enabled) -// Active: active (running) since Tue 2020-11-24 17:54:07 EST; 13h ago -// Main PID: 26741 (.tailscaled-wra) -// Status: "Connected; user@host.domain.tld; 100.101.102.103" -// IP: 0B in, 0B out -// Tasks: 22 (limit: 4915) -// Memory: 30.9M -// CPU: 2min 38.469s -// CGroup: /system.slice/tailscale.service -// └─26741 /nix/store/sv6cj4mw2jajm9xkbwj07k29dj30lh0n-tailscale-date.20200727/bin/tailscaled --port 41641 -func Status(format string, args ...any) { - err := notifier().Notify(sdnotify.Statusf(format, args...)) - if err != nil { - statusOnce.logf("systemd: error notifying: %v", err) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package systemd + +import ( + "errors" + "log" + "os" + "sync" + + "github.com/mdlayher/sdnotify" +) + +var getNotifyOnce struct { + sync.Once + v *sdnotify.Notifier +} + +type logOnce struct { + sync.Once +} + +func (l *logOnce) logf(format string, args ...any) { + l.Once.Do(func() { + log.Printf(format, args...) + }) +} + +var ( + readyOnce = &logOnce{} + statusOnce = &logOnce{} +) + +func notifier() *sdnotify.Notifier { + getNotifyOnce.Do(func() { + var err error + getNotifyOnce.v, err = sdnotify.New() + // Not exist means probably not running under systemd, so don't log. + if err != nil && !errors.Is(err, os.ErrNotExist) { + log.Printf("systemd: systemd-notifier error: %v", err) + } + }) + return getNotifyOnce.v +} + +// Ready signals readiness to systemd. This will unblock service dependents from starting. +func Ready() { + err := notifier().Notify(sdnotify.Ready) + if err != nil { + readyOnce.logf("systemd: error notifying: %v", err) + } +} + +// Status sends a single line status update to systemd so that information shows up +// in systemctl output. For example: +// +// $ systemctl status tailscale +// ● tailscale.service - Tailscale client daemon +// Loaded: loaded (/nix/store/qc312qcy907wz80fqrgbbm8a9djafmlg-unit-tailscale.service/tailscale.service; enabled; vendor preset: enabled) +// Active: active (running) since Tue 2020-11-24 17:54:07 EST; 13h ago +// Main PID: 26741 (.tailscaled-wra) +// Status: "Connected; user@host.domain.tld; 100.101.102.103" +// IP: 0B in, 0B out +// Tasks: 22 (limit: 4915) +// Memory: 30.9M +// CPU: 2min 38.469s +// CGroup: /system.slice/tailscale.service +// └─26741 /nix/store/sv6cj4mw2jajm9xkbwj07k29dj30lh0n-tailscale-date.20200727/bin/tailscaled --port 41641 +func Status(format string, args ...any) { + err := notifier().Notify(sdnotify.Statusf(format, args...)) + if err != nil { + statusOnce.logf("systemd: error notifying: %v", err) + } +} diff --git a/util/systemd/systemd_nonlinux.go b/util/systemd/systemd_nonlinux.go index d8b20665f..36214020c 100644 --- a/util/systemd/systemd_nonlinux.go +++ b/util/systemd/systemd_nonlinux.go @@ -1,9 +1,9 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux - -package systemd - -func Ready() {} -func Status(string, ...any) {} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux + +package systemd + +func Ready() {} +func Status(string, ...any) {} diff --git a/util/testenv/testenv.go b/util/testenv/testenv.go index 02c688803..12ada9003 100644 --- a/util/testenv/testenv.go +++ b/util/testenv/testenv.go @@ -1,21 +1,21 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package testenv provides utility functions for tests. It does not depend on -// the `testing` package to allow usage in non-test code. -package testenv - -import ( - "flag" - - "tailscale.com/types/lazy" -) - -var lazyInTest lazy.SyncValue[bool] - -// InTest reports whether the current binary is a test binary. -func InTest() bool { - return lazyInTest.Get(func() bool { - return flag.Lookup("test.v") != nil - }) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package testenv provides utility functions for tests. It does not depend on +// the `testing` package to allow usage in non-test code. +package testenv + +import ( + "flag" + + "tailscale.com/types/lazy" +) + +var lazyInTest lazy.SyncValue[bool] + +// InTest reports whether the current binary is a test binary. +func InTest() bool { + return lazyInTest.Get(func() bool { + return flag.Lookup("test.v") != nil + }) +} diff --git a/util/truncate/truncate_test.go b/util/truncate/truncate_test.go index 6ead55a6a..c0d9e6e14 100644 --- a/util/truncate/truncate_test.go +++ b/util/truncate/truncate_test.go @@ -1,36 +1,36 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package truncate_test - -import ( - "testing" - - "tailscale.com/util/truncate" -) - -func TestString(t *testing.T) { - tests := []struct { - input string - size int - want string - }{ - {"", 1000, ""}, // n > length - {"abc", 4, "abc"}, // n > length - {"abc", 3, "abc"}, // n == length - {"abcdefg", 4, "abcd"}, // n < length, safe - {"abcdefg", 0, ""}, // n < length, safe - {"abc\U0001fc2d", 3, "abc"}, // n < length, at boundary - {"abc\U0001fc2d", 4, "abc"}, // n < length, mid-rune - {"abc\U0001fc2d", 5, "abc"}, // n < length, mid-rune - {"abc\U0001fc2d", 6, "abc"}, // n < length, mid-rune - {"abc\U0001fc2defg", 7, "abc"}, // n < length, cut multibyte - } - - for _, tc := range tests { - got := truncate.String(tc.input, tc.size) - if got != tc.want { - t.Errorf("truncate(%q, %d): got %q, want %q", tc.input, tc.size, got, tc.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package truncate_test + +import ( + "testing" + + "tailscale.com/util/truncate" +) + +func TestString(t *testing.T) { + tests := []struct { + input string + size int + want string + }{ + {"", 1000, ""}, // n > length + {"abc", 4, "abc"}, // n > length + {"abc", 3, "abc"}, // n == length + {"abcdefg", 4, "abcd"}, // n < length, safe + {"abcdefg", 0, ""}, // n < length, safe + {"abc\U0001fc2d", 3, "abc"}, // n < length, at boundary + {"abc\U0001fc2d", 4, "abc"}, // n < length, mid-rune + {"abc\U0001fc2d", 5, "abc"}, // n < length, mid-rune + {"abc\U0001fc2d", 6, "abc"}, // n < length, mid-rune + {"abc\U0001fc2defg", 7, "abc"}, // n < length, cut multibyte + } + + for _, tc := range tests { + got := truncate.String(tc.input, tc.size) + if got != tc.want { + t.Errorf("truncate(%q, %d): got %q, want %q", tc.input, tc.size, got, tc.want) + } + } +} diff --git a/util/uniq/slice.go b/util/uniq/slice.go index fb46cc491..4ab933a9d 100644 --- a/util/uniq/slice.go +++ b/util/uniq/slice.go @@ -1,62 +1,62 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package uniq provides removal of adjacent duplicate elements in slices. -// It is similar to the unix command uniq. -package uniq - -// ModifySlice removes adjacent duplicate elements from the given slice. It -// adjusts the length of the slice appropriately and zeros the tail. -// -// ModifySlice does O(len(*slice)) operations. -func ModifySlice[E comparable](slice *[]E) { - // Remove duplicates - dst := 0 - for i := 1; i < len(*slice); i++ { - if (*slice)[i] == (*slice)[dst] { - continue - } - dst++ - (*slice)[dst] = (*slice)[i] - } - - // Zero out the elements we removed at the end of the slice - end := dst + 1 - var zero E - for i := end; i < len(*slice); i++ { - (*slice)[i] = zero - } - - // Truncate the slice - if end < len(*slice) { - *slice = (*slice)[:end] - } -} - -// ModifySliceFunc is the same as ModifySlice except that it allows using a -// custom comparison function. -// -// eq should report whether the two provided elements are equal. -func ModifySliceFunc[E any](slice *[]E, eq func(i, j E) bool) { - // Remove duplicates - dst := 0 - for i := 1; i < len(*slice); i++ { - if eq((*slice)[dst], (*slice)[i]) { - continue - } - dst++ - (*slice)[dst] = (*slice)[i] - } - - // Zero out the elements we removed at the end of the slice - end := dst + 1 - var zero E - for i := end; i < len(*slice); i++ { - (*slice)[i] = zero - } - - // Truncate the slice - if end < len(*slice) { - *slice = (*slice)[:end] - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package uniq provides removal of adjacent duplicate elements in slices. +// It is similar to the unix command uniq. +package uniq + +// ModifySlice removes adjacent duplicate elements from the given slice. It +// adjusts the length of the slice appropriately and zeros the tail. +// +// ModifySlice does O(len(*slice)) operations. +func ModifySlice[E comparable](slice *[]E) { + // Remove duplicates + dst := 0 + for i := 1; i < len(*slice); i++ { + if (*slice)[i] == (*slice)[dst] { + continue + } + dst++ + (*slice)[dst] = (*slice)[i] + } + + // Zero out the elements we removed at the end of the slice + end := dst + 1 + var zero E + for i := end; i < len(*slice); i++ { + (*slice)[i] = zero + } + + // Truncate the slice + if end < len(*slice) { + *slice = (*slice)[:end] + } +} + +// ModifySliceFunc is the same as ModifySlice except that it allows using a +// custom comparison function. +// +// eq should report whether the two provided elements are equal. +func ModifySliceFunc[E any](slice *[]E, eq func(i, j E) bool) { + // Remove duplicates + dst := 0 + for i := 1; i < len(*slice); i++ { + if eq((*slice)[dst], (*slice)[i]) { + continue + } + dst++ + (*slice)[dst] = (*slice)[i] + } + + // Zero out the elements we removed at the end of the slice + end := dst + 1 + var zero E + for i := end; i < len(*slice); i++ { + (*slice)[i] = zero + } + + // Truncate the slice + if end < len(*slice) { + *slice = (*slice)[:end] + } +} diff --git a/util/winutil/authenticode/mksyscall.go b/util/winutil/authenticode/mksyscall.go index 7c6b33973..8b7cabe6e 100644 --- a/util/winutil/authenticode/mksyscall.go +++ b/util/winutil/authenticode/mksyscall.go @@ -1,18 +1,18 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package authenticode - -//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go -//go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go - -//sys cryptCATAdminAcquireContext2(hCatAdmin *_HCATADMIN, pgSubsystem *windows.GUID, hashAlgorithm *uint16, strongHashPolicy *windows.CertStrongSignPara, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminAcquireContext2 -//sys cryptCATAdminCalcHashFromFileHandle2(hCatAdmin _HCATADMIN, file windows.Handle, pcbHash *uint32, pbHash *byte, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminCalcHashFromFileHandle2 -//sys cryptCATAdminCatalogInfoFromContext(hCatInfo _HCATINFO, catInfo *_CATALOG_INFO, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATCatalogInfoFromContext -//sys cryptCATAdminEnumCatalogFromHash(hCatAdmin _HCATADMIN, pbHash *byte, cbHash uint32, flags uint32, prevCatInfo *_HCATINFO) (ret _HCATINFO, err error) [ret==0] = wintrust.CryptCATAdminEnumCatalogFromHash -//sys cryptCATAdminReleaseCatalogContext(hCatAdmin _HCATADMIN, hCatInfo _HCATINFO, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminReleaseCatalogContext -//sys cryptCATAdminReleaseContext(hCatAdmin _HCATADMIN, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminReleaseContext -//sys cryptMsgClose(cryptMsg windows.Handle) (err error) [int32(failretval)==0] = crypt32.CryptMsgClose -//sys cryptMsgGetParam(cryptMsg windows.Handle, paramType uint32, index uint32, data unsafe.Pointer, dataLen *uint32) (err error) [int32(failretval)==0] = crypt32.CryptMsgGetParam -//sys cryptVerifyMessageSignature(pVerifyPara *_CRYPT_VERIFY_MESSAGE_PARA, signerIndex uint32, pbSignedBlob *byte, cbSignedBlob uint32, pbDecoded *byte, pdbDecoded *uint32, ppSignerCert **windows.CertContext) (err error) [int32(failretval)==0] = crypt32.CryptVerifyMessageSignature -//sys msiGetFileSignatureInformation(signedObjectPath *uint16, flags uint32, certCtx **windows.CertContext, pbHashData *byte, cbHashData *uint32) (ret wingoes.HRESULT) = msi.MsiGetFileSignatureInformationW +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package authenticode + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go +//go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go + +//sys cryptCATAdminAcquireContext2(hCatAdmin *_HCATADMIN, pgSubsystem *windows.GUID, hashAlgorithm *uint16, strongHashPolicy *windows.CertStrongSignPara, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminAcquireContext2 +//sys cryptCATAdminCalcHashFromFileHandle2(hCatAdmin _HCATADMIN, file windows.Handle, pcbHash *uint32, pbHash *byte, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminCalcHashFromFileHandle2 +//sys cryptCATAdminCatalogInfoFromContext(hCatInfo _HCATINFO, catInfo *_CATALOG_INFO, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATCatalogInfoFromContext +//sys cryptCATAdminEnumCatalogFromHash(hCatAdmin _HCATADMIN, pbHash *byte, cbHash uint32, flags uint32, prevCatInfo *_HCATINFO) (ret _HCATINFO, err error) [ret==0] = wintrust.CryptCATAdminEnumCatalogFromHash +//sys cryptCATAdminReleaseCatalogContext(hCatAdmin _HCATADMIN, hCatInfo _HCATINFO, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminReleaseCatalogContext +//sys cryptCATAdminReleaseContext(hCatAdmin _HCATADMIN, flags uint32) (err error) [int32(failretval)==0] = wintrust.CryptCATAdminReleaseContext +//sys cryptMsgClose(cryptMsg windows.Handle) (err error) [int32(failretval)==0] = crypt32.CryptMsgClose +//sys cryptMsgGetParam(cryptMsg windows.Handle, paramType uint32, index uint32, data unsafe.Pointer, dataLen *uint32) (err error) [int32(failretval)==0] = crypt32.CryptMsgGetParam +//sys cryptVerifyMessageSignature(pVerifyPara *_CRYPT_VERIFY_MESSAGE_PARA, signerIndex uint32, pbSignedBlob *byte, cbSignedBlob uint32, pbDecoded *byte, pdbDecoded *uint32, ppSignerCert **windows.CertContext) (err error) [int32(failretval)==0] = crypt32.CryptVerifyMessageSignature +//sys msiGetFileSignatureInformation(signedObjectPath *uint16, flags uint32, certCtx **windows.CertContext, pbHashData *byte, cbHashData *uint32) (ret wingoes.HRESULT) = msi.MsiGetFileSignatureInformationW diff --git a/util/winutil/policy/policy_windows.go b/util/winutil/policy/policy_windows.go index 4674696fa..89142951f 100644 --- a/util/winutil/policy/policy_windows.go +++ b/util/winutil/policy/policy_windows.go @@ -1,155 +1,155 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package policy contains higher-level abstractions for accessing Windows enterprise policies. -package policy - -import ( - "time" - - "tailscale.com/util/winutil" -) - -// PreferenceOptionPolicy is a policy that governs whether a boolean variable -// is forcibly assigned an administrator-defined value, or allowed to receive -// a user-defined value. -type PreferenceOptionPolicy int - -const ( - showChoiceByPolicy PreferenceOptionPolicy = iota - neverByPolicy - alwaysByPolicy -) - -// Show returns if the UI option that controls the choice administered by this -// policy should be shown. Currently this is true if and only if the policy is -// showChoiceByPolicy. -func (p PreferenceOptionPolicy) Show() bool { - return p == showChoiceByPolicy -} - -// ShouldEnable checks if the choice administered by this policy should be -// enabled. If the administrator has chosen a setting, the administrator's -// setting is returned, otherwise userChoice is returned. -func (p PreferenceOptionPolicy) ShouldEnable(userChoice bool) bool { - switch p { - case neverByPolicy: - return false - case alwaysByPolicy: - return true - default: - return userChoice - } -} - -// GetPreferenceOptionPolicy loads a policy from the registry that can be -// managed by an enterprise policy management system and allows administrative -// overrides of users' choices in a way that we do not want tailcontrol to have -// the authority to set. It describes user-decides/always/never options, where -// "always" and "never" remove the user's ability to make a selection. If not -// present or set to a different value, "user-decides" is the default. -func GetPreferenceOptionPolicy(name string) PreferenceOptionPolicy { - opt, err := winutil.GetPolicyString(name) - if opt == "" || err != nil { - return showChoiceByPolicy - } - switch opt { - case "always": - return alwaysByPolicy - case "never": - return neverByPolicy - default: - return showChoiceByPolicy - } -} - -// VisibilityPolicy is a policy that controls whether or not a particular -// component of a user interface is to be shown. -type VisibilityPolicy byte - -const ( - visibleByPolicy VisibilityPolicy = 'v' - hiddenByPolicy VisibilityPolicy = 'h' -) - -// Show reports whether the UI option administered by this policy should be shown. -// Currently this is true if and only if the policy is visibleByPolicy. -func (p VisibilityPolicy) Show() bool { - return p == visibleByPolicy -} - -// GetVisibilityPolicy loads a policy from the registry that can be managed -// by an enterprise policy management system and describes show/hide decisions -// for UI elements. The registry value should be a string set to "show" (return -// true) or "hide" (return true). If not present or set to a different value, -// "show" (return false) is the default. -func GetVisibilityPolicy(name string) VisibilityPolicy { - opt, err := winutil.GetPolicyString(name) - if opt == "" || err != nil { - return visibleByPolicy - } - switch opt { - case "hide": - return hiddenByPolicy - default: - return visibleByPolicy - } -} - -// GetDurationPolicy loads a policy from the registry that can be managed -// by an enterprise policy management system and describes a duration for some -// action. The registry value should be a string that time.ParseDuration -// understands. If the registry value is "" or can not be processed, -// defaultValue is returned instead. -func GetDurationPolicy(name string, defaultValue time.Duration) time.Duration { - opt, err := winutil.GetPolicyString(name) - if opt == "" || err != nil { - return defaultValue - } - v, err := time.ParseDuration(opt) - if err != nil || v < 0 { - return defaultValue - } - return v -} - -// SelectControlURL returns the ControlURL to use based on a value in -// the registry (LoginURL) and the one on disk (in the GUI's -// prefs.conf). If both are empty, it returns a default value. (It -// always return a non-empty value) -// -// See https://github.com/tailscale/tailscale/issues/2798 for some background. -func SelectControlURL(reg, disk string) string { - const def = "https://controlplane.tailscale.com" - - // Prior to Dec 2020's commit 739b02e6, the installer - // wrote a LoginURL value of https://login.tailscale.com to the registry. - const oldRegDef = "https://login.tailscale.com" - - // If they have an explicit value in the registry, use it, - // unless it's an old default value from an old installer. - // Then we have to see which is better. - if reg != "" { - if reg != oldRegDef { - // Something explicit in the registry that we didn't - // set ourselves by the installer. - return reg - } - if disk == "" { - // Something in the registry is better than nothing on disk. - return reg - } - if disk != def && disk != oldRegDef { - // The value in the registry is the old - // default (login.tailscale.com) but the value - // on disk is neither our old nor new default - // value, so it must be some custom thing that - // the user cares about. Prefer the disk value. - return disk - } - } - if disk != "" { - return disk - } - return def -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package policy contains higher-level abstractions for accessing Windows enterprise policies. +package policy + +import ( + "time" + + "tailscale.com/util/winutil" +) + +// PreferenceOptionPolicy is a policy that governs whether a boolean variable +// is forcibly assigned an administrator-defined value, or allowed to receive +// a user-defined value. +type PreferenceOptionPolicy int + +const ( + showChoiceByPolicy PreferenceOptionPolicy = iota + neverByPolicy + alwaysByPolicy +) + +// Show returns if the UI option that controls the choice administered by this +// policy should be shown. Currently this is true if and only if the policy is +// showChoiceByPolicy. +func (p PreferenceOptionPolicy) Show() bool { + return p == showChoiceByPolicy +} + +// ShouldEnable checks if the choice administered by this policy should be +// enabled. If the administrator has chosen a setting, the administrator's +// setting is returned, otherwise userChoice is returned. +func (p PreferenceOptionPolicy) ShouldEnable(userChoice bool) bool { + switch p { + case neverByPolicy: + return false + case alwaysByPolicy: + return true + default: + return userChoice + } +} + +// GetPreferenceOptionPolicy loads a policy from the registry that can be +// managed by an enterprise policy management system and allows administrative +// overrides of users' choices in a way that we do not want tailcontrol to have +// the authority to set. It describes user-decides/always/never options, where +// "always" and "never" remove the user's ability to make a selection. If not +// present or set to a different value, "user-decides" is the default. +func GetPreferenceOptionPolicy(name string) PreferenceOptionPolicy { + opt, err := winutil.GetPolicyString(name) + if opt == "" || err != nil { + return showChoiceByPolicy + } + switch opt { + case "always": + return alwaysByPolicy + case "never": + return neverByPolicy + default: + return showChoiceByPolicy + } +} + +// VisibilityPolicy is a policy that controls whether or not a particular +// component of a user interface is to be shown. +type VisibilityPolicy byte + +const ( + visibleByPolicy VisibilityPolicy = 'v' + hiddenByPolicy VisibilityPolicy = 'h' +) + +// Show reports whether the UI option administered by this policy should be shown. +// Currently this is true if and only if the policy is visibleByPolicy. +func (p VisibilityPolicy) Show() bool { + return p == visibleByPolicy +} + +// GetVisibilityPolicy loads a policy from the registry that can be managed +// by an enterprise policy management system and describes show/hide decisions +// for UI elements. The registry value should be a string set to "show" (return +// true) or "hide" (return true). If not present or set to a different value, +// "show" (return false) is the default. +func GetVisibilityPolicy(name string) VisibilityPolicy { + opt, err := winutil.GetPolicyString(name) + if opt == "" || err != nil { + return visibleByPolicy + } + switch opt { + case "hide": + return hiddenByPolicy + default: + return visibleByPolicy + } +} + +// GetDurationPolicy loads a policy from the registry that can be managed +// by an enterprise policy management system and describes a duration for some +// action. The registry value should be a string that time.ParseDuration +// understands. If the registry value is "" or can not be processed, +// defaultValue is returned instead. +func GetDurationPolicy(name string, defaultValue time.Duration) time.Duration { + opt, err := winutil.GetPolicyString(name) + if opt == "" || err != nil { + return defaultValue + } + v, err := time.ParseDuration(opt) + if err != nil || v < 0 { + return defaultValue + } + return v +} + +// SelectControlURL returns the ControlURL to use based on a value in +// the registry (LoginURL) and the one on disk (in the GUI's +// prefs.conf). If both are empty, it returns a default value. (It +// always return a non-empty value) +// +// See https://github.com/tailscale/tailscale/issues/2798 for some background. +func SelectControlURL(reg, disk string) string { + const def = "https://controlplane.tailscale.com" + + // Prior to Dec 2020's commit 739b02e6, the installer + // wrote a LoginURL value of https://login.tailscale.com to the registry. + const oldRegDef = "https://login.tailscale.com" + + // If they have an explicit value in the registry, use it, + // unless it's an old default value from an old installer. + // Then we have to see which is better. + if reg != "" { + if reg != oldRegDef { + // Something explicit in the registry that we didn't + // set ourselves by the installer. + return reg + } + if disk == "" { + // Something in the registry is better than nothing on disk. + return reg + } + if disk != def && disk != oldRegDef { + // The value in the registry is the old + // default (login.tailscale.com) but the value + // on disk is neither our old nor new default + // value, so it must be some custom thing that + // the user cares about. Prefer the disk value. + return disk + } + } + if disk != "" { + return disk + } + return def +} diff --git a/util/winutil/policy/policy_windows_test.go b/util/winutil/policy/policy_windows_test.go index ebfd185de..cf2390c56 100644 --- a/util/winutil/policy/policy_windows_test.go +++ b/util/winutil/policy/policy_windows_test.go @@ -1,38 +1,38 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package policy - -import "testing" - -func TestSelectControlURL(t *testing.T) { - tests := []struct { - reg, disk, want string - }{ - // Modern default case. - {"", "", "https://controlplane.tailscale.com"}, - - // For a user who installed prior to Dec 2020, with - // stuff in their registry. - {"https://login.tailscale.com", "", "https://login.tailscale.com"}, - - // Ignore pre-Dec'20 LoginURL from installer if prefs - // prefs overridden manually to an on-prem control - // server. - {"https://login.tailscale.com", "http://on-prem", "http://on-prem"}, - - // Something unknown explicitly set in the registry always wins. - {"http://explicit-reg", "", "http://explicit-reg"}, - {"http://explicit-reg", "http://on-prem", "http://explicit-reg"}, - {"http://explicit-reg", "https://login.tailscale.com", "http://explicit-reg"}, - {"http://explicit-reg", "https://controlplane.tailscale.com", "http://explicit-reg"}, - - // If nothing in the registry, disk wins. - {"", "http://on-prem", "http://on-prem"}, - } - for _, tt := range tests { - if got := SelectControlURL(tt.reg, tt.disk); got != tt.want { - t.Errorf("(reg %q, disk %q) = %q; want %q", tt.reg, tt.disk, got, tt.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package policy + +import "testing" + +func TestSelectControlURL(t *testing.T) { + tests := []struct { + reg, disk, want string + }{ + // Modern default case. + {"", "", "https://controlplane.tailscale.com"}, + + // For a user who installed prior to Dec 2020, with + // stuff in their registry. + {"https://login.tailscale.com", "", "https://login.tailscale.com"}, + + // Ignore pre-Dec'20 LoginURL from installer if prefs + // prefs overridden manually to an on-prem control + // server. + {"https://login.tailscale.com", "http://on-prem", "http://on-prem"}, + + // Something unknown explicitly set in the registry always wins. + {"http://explicit-reg", "", "http://explicit-reg"}, + {"http://explicit-reg", "http://on-prem", "http://explicit-reg"}, + {"http://explicit-reg", "https://login.tailscale.com", "http://explicit-reg"}, + {"http://explicit-reg", "https://controlplane.tailscale.com", "http://explicit-reg"}, + + // If nothing in the registry, disk wins. + {"", "http://on-prem", "http://on-prem"}, + } + for _, tt := range tests { + if got := SelectControlURL(tt.reg, tt.disk); got != tt.want { + t.Errorf("(reg %q, disk %q) = %q; want %q", tt.reg, tt.disk, got, tt.want) + } + } +} diff --git a/version/.gitignore b/version/.gitignore index 8878450fa..58d19bfc2 100644 --- a/version/.gitignore +++ b/version/.gitignore @@ -1,10 +1,10 @@ -describe.txt -long.txt -short.txt -gitcommit.txt -extragitcommit.txt -version-info.sh -version.h -version.xcconfig -ver.go -version +describe.txt +long.txt +short.txt +gitcommit.txt +extragitcommit.txt +version-info.sh +version.h +version.xcconfig +ver.go +version diff --git a/version/cmdname.go b/version/cmdname.go index 9f85ef96d..51e065438 100644 --- a/version/cmdname.go +++ b/version/cmdname.go @@ -1,139 +1,139 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !ios - -package version - -import ( - "bytes" - "encoding/hex" - "errors" - "io" - "os" - "path" - "path/filepath" - "strings" -) - -// CmdName returns either the base name of the current binary -// using os.Executable. If os.Executable fails (it shouldn't), then -// "cmd" is returned. -func CmdName() string { - e, err := os.Executable() - if err != nil { - return "cmd" - } - return cmdName(e) -} - -func cmdName(exe string) string { - // fallbackName, the lowercase basename of the executable, is what we return if - // we can't find the Go module metadata embedded in the file. - fallbackName := filepath.Base(strings.TrimSuffix(strings.ToLower(exe), ".exe")) - - var ret string - info, err := findModuleInfo(exe) - if err != nil { - return fallbackName - } - // v is like: - // "path\ttailscale.com/cmd/tailscale\nmod\ttailscale.com\t(devel)\t\ndep\tgithub.com/apenwarr/fixconsole\tv0.0.0-20191012055117-5a9f6489cc29\th1:muXWUcay7DDy1/hEQWrYlBy+g0EuwT70sBHg65SeUc4=\ndep\tgithub.... - for _, line := range strings.Split(info, "\n") { - if goPkg, ok := strings.CutPrefix(line, "path\t"); ok { // like "tailscale.com/cmd/tailscale" - ret = path.Base(goPkg) // goPkg is always forward slashes; use path, not filepath - break - } - } - if strings.HasPrefix(ret, "wg") && fallbackName == "tailscale-ipn" { - // The tailscale-ipn.exe binary for internal build system packaging reasons - // has a path of "tailscale.io/win/wg64", "tailscale.io/win/wg32", etc. - // Ignore that name and use "tailscale-ipn" instead. - return fallbackName - } - if ret == "" { - return fallbackName - } - return ret -} - -// findModuleInfo returns the Go module info from the executable file. -func findModuleInfo(file string) (s string, err error) { - f, err := os.Open(file) - if err != nil { - return "", err - } - defer f.Close() - // Scan through f until we find infoStart. - buf := make([]byte, 65536) - start, err := findOffset(f, buf, infoStart) - if err != nil { - return "", err - } - start += int64(len(infoStart)) - // Seek to the end of infoStart and scan for infoEnd. - _, err = f.Seek(start, io.SeekStart) - if err != nil { - return "", err - } - end, err := findOffset(f, buf, infoEnd) - if err != nil { - return "", err - } - length := end - start - // As of Aug 2021, tailscaled's mod info was about 2k. - if length > int64(len(buf)) { - return "", errors.New("mod info too large") - } - // We have located modinfo. Read it into buf. - buf = buf[:length] - _, err = f.Seek(start, io.SeekStart) - if err != nil { - return "", err - } - _, err = io.ReadFull(f, buf) - if err != nil { - return "", err - } - return string(buf), nil -} - -// findOffset finds the absolute offset of needle in f, -// starting at f's current read position, -// using temporary buffer buf. -func findOffset(f *os.File, buf, needle []byte) (int64, error) { - for { - // Fill buf and look within it. - n, err := f.Read(buf) - if err != nil { - return -1, err - } - i := bytes.Index(buf[:n], needle) - if i < 0 { - // Not found. Rewind a little bit in case we happened to end halfway through needle. - rewind, err := f.Seek(int64(-len(needle)), io.SeekCurrent) - if err != nil { - return -1, err - } - // If we're at EOF and rewound exactly len(needle) bytes, return io.EOF. - _, err = f.ReadAt(buf[:1], rewind+int64(len(needle))) - if err == io.EOF { - return -1, err - } - continue - } - // Found! Figure out exactly where. - cur, err := f.Seek(0, io.SeekCurrent) - if err != nil { - return -1, err - } - return cur - int64(n) + int64(i), nil - } -} - -// These constants are taken from rsc.io/goversion. - -var ( - infoStart, _ = hex.DecodeString("3077af0c9274080241e1c107e6d618e6") - infoEnd, _ = hex.DecodeString("f932433186182072008242104116d8f2") -) +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ios + +package version + +import ( + "bytes" + "encoding/hex" + "errors" + "io" + "os" + "path" + "path/filepath" + "strings" +) + +// CmdName returns either the base name of the current binary +// using os.Executable. If os.Executable fails (it shouldn't), then +// "cmd" is returned. +func CmdName() string { + e, err := os.Executable() + if err != nil { + return "cmd" + } + return cmdName(e) +} + +func cmdName(exe string) string { + // fallbackName, the lowercase basename of the executable, is what we return if + // we can't find the Go module metadata embedded in the file. + fallbackName := filepath.Base(strings.TrimSuffix(strings.ToLower(exe), ".exe")) + + var ret string + info, err := findModuleInfo(exe) + if err != nil { + return fallbackName + } + // v is like: + // "path\ttailscale.com/cmd/tailscale\nmod\ttailscale.com\t(devel)\t\ndep\tgithub.com/apenwarr/fixconsole\tv0.0.0-20191012055117-5a9f6489cc29\th1:muXWUcay7DDy1/hEQWrYlBy+g0EuwT70sBHg65SeUc4=\ndep\tgithub.... + for _, line := range strings.Split(info, "\n") { + if goPkg, ok := strings.CutPrefix(line, "path\t"); ok { // like "tailscale.com/cmd/tailscale" + ret = path.Base(goPkg) // goPkg is always forward slashes; use path, not filepath + break + } + } + if strings.HasPrefix(ret, "wg") && fallbackName == "tailscale-ipn" { + // The tailscale-ipn.exe binary for internal build system packaging reasons + // has a path of "tailscale.io/win/wg64", "tailscale.io/win/wg32", etc. + // Ignore that name and use "tailscale-ipn" instead. + return fallbackName + } + if ret == "" { + return fallbackName + } + return ret +} + +// findModuleInfo returns the Go module info from the executable file. +func findModuleInfo(file string) (s string, err error) { + f, err := os.Open(file) + if err != nil { + return "", err + } + defer f.Close() + // Scan through f until we find infoStart. + buf := make([]byte, 65536) + start, err := findOffset(f, buf, infoStart) + if err != nil { + return "", err + } + start += int64(len(infoStart)) + // Seek to the end of infoStart and scan for infoEnd. + _, err = f.Seek(start, io.SeekStart) + if err != nil { + return "", err + } + end, err := findOffset(f, buf, infoEnd) + if err != nil { + return "", err + } + length := end - start + // As of Aug 2021, tailscaled's mod info was about 2k. + if length > int64(len(buf)) { + return "", errors.New("mod info too large") + } + // We have located modinfo. Read it into buf. + buf = buf[:length] + _, err = f.Seek(start, io.SeekStart) + if err != nil { + return "", err + } + _, err = io.ReadFull(f, buf) + if err != nil { + return "", err + } + return string(buf), nil +} + +// findOffset finds the absolute offset of needle in f, +// starting at f's current read position, +// using temporary buffer buf. +func findOffset(f *os.File, buf, needle []byte) (int64, error) { + for { + // Fill buf and look within it. + n, err := f.Read(buf) + if err != nil { + return -1, err + } + i := bytes.Index(buf[:n], needle) + if i < 0 { + // Not found. Rewind a little bit in case we happened to end halfway through needle. + rewind, err := f.Seek(int64(-len(needle)), io.SeekCurrent) + if err != nil { + return -1, err + } + // If we're at EOF and rewound exactly len(needle) bytes, return io.EOF. + _, err = f.ReadAt(buf[:1], rewind+int64(len(needle))) + if err == io.EOF { + return -1, err + } + continue + } + // Found! Figure out exactly where. + cur, err := f.Seek(0, io.SeekCurrent) + if err != nil { + return -1, err + } + return cur - int64(n) + int64(i), nil + } +} + +// These constants are taken from rsc.io/goversion. + +var ( + infoStart, _ = hex.DecodeString("3077af0c9274080241e1c107e6d618e6") + infoEnd, _ = hex.DecodeString("f932433186182072008242104116d8f2") +) diff --git a/version/cmdname_ios.go b/version/cmdname_ios.go index 5e338944c..6bfed38b6 100644 --- a/version/cmdname_ios.go +++ b/version/cmdname_ios.go @@ -1,18 +1,18 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ios - -package version - -import ( - "os" -) - -func CmdName() string { - e, err := os.Executable() - if err != nil { - return "cmd" - } - return e -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ios + +package version + +import ( + "os" +) + +func CmdName() string { + e, err := os.Executable() + if err != nil { + return "cmd" + } + return e +} diff --git a/version/cmp_test.go b/version/cmp_test.go index 59153f0dd..e244d5e16 100644 --- a/version/cmp_test.go +++ b/version/cmp_test.go @@ -1,82 +1,82 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package version_test - -import ( - "testing" - - "github.com/google/go-cmp/cmp" - "tailscale.com/tstest" - "tailscale.com/version" -) - -func TestParse(t *testing.T) { - parse := version.ExportParse - type parsed = version.ExportParsed - - tests := []struct { - version string - parsed parsed - want bool - }{ - {"1", parsed{Major: 1}, true}, - {"1.2", parsed{Major: 1, Minor: 2}, true}, - {"1.2.3", parsed{Major: 1, Minor: 2, Patch: 3}, true}, - {"1.2.3-4", parsed{Major: 1, Minor: 2, Patch: 3, ExtraCommits: 4}, true}, - {"1.2-4", parsed{Major: 1, Minor: 2, ExtraCommits: 4}, true}, - {"1.2.3-4-extra", parsed{Major: 1, Minor: 2, Patch: 3, ExtraCommits: 4}, true}, - {"1.2.3-4a-test", parsed{Major: 1, Minor: 2, Patch: 3}, true}, - {"1.2-extra", parsed{Major: 1, Minor: 2}, true}, - {"1.2.3-extra", parsed{Major: 1, Minor: 2, Patch: 3}, true}, - {"date.20200612", parsed{Datestamp: 20200612}, true}, - {"borkbork", parsed{}, false}, - {"1a.2.3", parsed{}, false}, - {"", parsed{}, false}, - } - - for _, test := range tests { - gotParsed, got := parse(test.version) - if got != test.want { - t.Errorf("version(%q) = %v, want %v", test.version, got, test.want) - } - if diff := cmp.Diff(gotParsed, test.parsed); diff != "" { - t.Errorf("parse(%q) diff (-got+want):\n%s", test.version, diff) - } - err := tstest.MinAllocsPerRun(t, 0, func() { - gotParsed, got = parse(test.version) - }) - if err != nil { - t.Errorf("parse(%q): %v", test.version, err) - } - } -} - -func TestAtLeast(t *testing.T) { - tests := []struct { - v, m string - want bool - }{ - {"1", "1", true}, - {"1.2", "1", true}, - {"1.2.3", "1", true}, - {"1.2.3-4", "1", true}, - {"0.98-0", "0.98", true}, - {"0.97.1-216", "0.98", false}, - {"0.94", "0.98", false}, - {"0.98", "0.98", true}, - {"0.98.0-0", "0.98", true}, - {"1.2.3-4", "1.2.4-4", false}, - {"1.2.3-4", "1.2.3-4", true}, - {"date.20200612", "date.20200612", true}, - {"date.20200701", "date.20200612", true}, - {"date.20200501", "date.20200612", false}, - } - - for _, test := range tests { - got := version.AtLeast(test.v, test.m) - if got != test.want { - t.Errorf("AtLeast(%q, %q) = %v, want %v", test.v, test.m, got, test.want) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package version_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/tstest" + "tailscale.com/version" +) + +func TestParse(t *testing.T) { + parse := version.ExportParse + type parsed = version.ExportParsed + + tests := []struct { + version string + parsed parsed + want bool + }{ + {"1", parsed{Major: 1}, true}, + {"1.2", parsed{Major: 1, Minor: 2}, true}, + {"1.2.3", parsed{Major: 1, Minor: 2, Patch: 3}, true}, + {"1.2.3-4", parsed{Major: 1, Minor: 2, Patch: 3, ExtraCommits: 4}, true}, + {"1.2-4", parsed{Major: 1, Minor: 2, ExtraCommits: 4}, true}, + {"1.2.3-4-extra", parsed{Major: 1, Minor: 2, Patch: 3, ExtraCommits: 4}, true}, + {"1.2.3-4a-test", parsed{Major: 1, Minor: 2, Patch: 3}, true}, + {"1.2-extra", parsed{Major: 1, Minor: 2}, true}, + {"1.2.3-extra", parsed{Major: 1, Minor: 2, Patch: 3}, true}, + {"date.20200612", parsed{Datestamp: 20200612}, true}, + {"borkbork", parsed{}, false}, + {"1a.2.3", parsed{}, false}, + {"", parsed{}, false}, + } + + for _, test := range tests { + gotParsed, got := parse(test.version) + if got != test.want { + t.Errorf("version(%q) = %v, want %v", test.version, got, test.want) + } + if diff := cmp.Diff(gotParsed, test.parsed); diff != "" { + t.Errorf("parse(%q) diff (-got+want):\n%s", test.version, diff) + } + err := tstest.MinAllocsPerRun(t, 0, func() { + gotParsed, got = parse(test.version) + }) + if err != nil { + t.Errorf("parse(%q): %v", test.version, err) + } + } +} + +func TestAtLeast(t *testing.T) { + tests := []struct { + v, m string + want bool + }{ + {"1", "1", true}, + {"1.2", "1", true}, + {"1.2.3", "1", true}, + {"1.2.3-4", "1", true}, + {"0.98-0", "0.98", true}, + {"0.97.1-216", "0.98", false}, + {"0.94", "0.98", false}, + {"0.98", "0.98", true}, + {"0.98.0-0", "0.98", true}, + {"1.2.3-4", "1.2.4-4", false}, + {"1.2.3-4", "1.2.3-4", true}, + {"date.20200612", "date.20200612", true}, + {"date.20200701", "date.20200612", true}, + {"date.20200501", "date.20200612", false}, + } + + for _, test := range tests { + got := version.AtLeast(test.v, test.m) + if got != test.want { + t.Errorf("AtLeast(%q, %q) = %v, want %v", test.v, test.m, got, test.want) + } + } +} diff --git a/version/export_test.go b/version/export_test.go index fabba13e8..8e8ce5ecb 100644 --- a/version/export_test.go +++ b/version/export_test.go @@ -1,14 +1,14 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package version - -var ( - ExportParse = parse - ExportFindModuleInfo = findModuleInfo - ExportCmdName = cmdName -) - -type ( - ExportParsed = parsed -) +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package version + +var ( + ExportParse = parse + ExportFindModuleInfo = findModuleInfo + ExportCmdName = cmdName +) + +type ( + ExportParsed = parsed +) diff --git a/version/print.go b/version/print.go index e3bfc38ef..7d8554279 100644 --- a/version/print.go +++ b/version/print.go @@ -1,33 +1,33 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package version - -import ( - "fmt" - "runtime" - "strings" - - "tailscale.com/types/lazy" -) - -var stringLazy = lazy.SyncFunc(func() string { - var ret strings.Builder - ret.WriteString(Short()) - ret.WriteByte('\n') - if IsUnstableBuild() { - fmt.Fprintf(&ret, " track: unstable (dev); frequent updates and bugs are likely\n") - } - if gitCommit() != "" { - fmt.Fprintf(&ret, " tailscale commit: %s%s\n", gitCommit(), dirtyString()) - } - if extraGitCommitStamp != "" { - fmt.Fprintf(&ret, " other commit: %s\n", extraGitCommitStamp) - } - fmt.Fprintf(&ret, " go version: %s\n", runtime.Version()) - return strings.TrimSpace(ret.String()) -}) - -func String() string { - return stringLazy() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package version + +import ( + "fmt" + "runtime" + "strings" + + "tailscale.com/types/lazy" +) + +var stringLazy = lazy.SyncFunc(func() string { + var ret strings.Builder + ret.WriteString(Short()) + ret.WriteByte('\n') + if IsUnstableBuild() { + fmt.Fprintf(&ret, " track: unstable (dev); frequent updates and bugs are likely\n") + } + if gitCommit() != "" { + fmt.Fprintf(&ret, " tailscale commit: %s%s\n", gitCommit(), dirtyString()) + } + if extraGitCommitStamp != "" { + fmt.Fprintf(&ret, " other commit: %s\n", extraGitCommitStamp) + } + fmt.Fprintf(&ret, " go version: %s\n", runtime.Version()) + return strings.TrimSpace(ret.String()) +}) + +func String() string { + return stringLazy() +} diff --git a/version/race.go b/version/race.go index bc3ca8db6..e1dc76591 100644 --- a/version/race.go +++ b/version/race.go @@ -1,10 +1,10 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build race - -package version - -// IsRace reports whether the current binary was built with the Go -// race detector enabled. -func IsRace() bool { return true } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build race + +package version + +// IsRace reports whether the current binary was built with the Go +// race detector enabled. +func IsRace() bool { return true } diff --git a/version/race_off.go b/version/race_off.go index d55288d9c..6db901974 100644 --- a/version/race_off.go +++ b/version/race_off.go @@ -1,10 +1,10 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !race - -package version - -// IsRace reports whether the current binary was built with the Go -// race detector enabled. -func IsRace() bool { return false } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !race + +package version + +// IsRace reports whether the current binary was built with the Go +// race detector enabled. +func IsRace() bool { return false } diff --git a/version/version_test.go b/version/version_test.go index 4d676f9f5..a51565058 100644 --- a/version/version_test.go +++ b/version/version_test.go @@ -1,51 +1,51 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package version_test - -import ( - "bytes" - "os" - "testing" - - ts "tailscale.com" - "tailscale.com/version" -) - -func TestAlpineTag(t *testing.T) { - if tag := readAlpineTag(t, "../Dockerfile.base"); tag == "" { - t.Fatal(`"FROM alpine:" not found in Dockerfile.base`) - } else if tag != ts.AlpineDockerTag { - t.Errorf("alpine version mismatch: Dockerfile.base has %q; ALPINE.txt has %q", tag, ts.AlpineDockerTag) - } - if tag := readAlpineTag(t, "../Dockerfile"); tag == "" { - t.Fatal(`"FROM alpine:" not found in Dockerfile`) - } else if tag != ts.AlpineDockerTag { - t.Errorf("alpine version mismatch: Dockerfile has %q; ALPINE.txt has %q", tag, ts.AlpineDockerTag) - } -} - -func readAlpineTag(t *testing.T, file string) string { - f, err := os.ReadFile(file) - if err != nil { - t.Fatal(err) - } - for _, line := range bytes.Split(f, []byte{'\n'}) { - line = bytes.TrimSpace(line) - _, suf, ok := bytes.Cut(line, []byte("FROM alpine:")) - if !ok { - continue - } - return string(suf) - } - return "" -} - -func TestShortAllocs(t *testing.T) { - allocs := int(testing.AllocsPerRun(10000, func() { - _ = version.Short() - })) - if allocs > 0 { - t.Errorf("allocs = %v; want 0", allocs) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package version_test + +import ( + "bytes" + "os" + "testing" + + ts "tailscale.com" + "tailscale.com/version" +) + +func TestAlpineTag(t *testing.T) { + if tag := readAlpineTag(t, "../Dockerfile.base"); tag == "" { + t.Fatal(`"FROM alpine:" not found in Dockerfile.base`) + } else if tag != ts.AlpineDockerTag { + t.Errorf("alpine version mismatch: Dockerfile.base has %q; ALPINE.txt has %q", tag, ts.AlpineDockerTag) + } + if tag := readAlpineTag(t, "../Dockerfile"); tag == "" { + t.Fatal(`"FROM alpine:" not found in Dockerfile`) + } else if tag != ts.AlpineDockerTag { + t.Errorf("alpine version mismatch: Dockerfile has %q; ALPINE.txt has %q", tag, ts.AlpineDockerTag) + } +} + +func readAlpineTag(t *testing.T, file string) string { + f, err := os.ReadFile(file) + if err != nil { + t.Fatal(err) + } + for _, line := range bytes.Split(f, []byte{'\n'}) { + line = bytes.TrimSpace(line) + _, suf, ok := bytes.Cut(line, []byte("FROM alpine:")) + if !ok { + continue + } + return string(suf) + } + return "" +} + +func TestShortAllocs(t *testing.T) { + allocs := int(testing.AllocsPerRun(10000, func() { + _ = version.Short() + })) + if allocs > 0 { + t.Errorf("allocs = %v; want 0", allocs) + } +} diff --git a/wgengine/bench/bench.go b/wgengine/bench/bench.go index b94930ee5..8695f18d1 100644 --- a/wgengine/bench/bench.go +++ b/wgengine/bench/bench.go @@ -1,409 +1,409 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Create two wgengine instances and pass data through them, measuring -// throughput, latency, and packet loss. -package main - -import ( - "bufio" - "io" - "log" - "net" - "net/http" - "net/http/pprof" - "net/netip" - "os" - "strconv" - "sync" - "time" - - "tailscale.com/types/logger" -) - -const PayloadSize = 1000 -const ICMPMinSize = 24 - -var Addr1 = netip.MustParsePrefix("100.64.1.1/32") -var Addr2 = netip.MustParsePrefix("100.64.1.2/32") - -func main() { - var logf logger.Logf = log.Printf - log.SetFlags(0) - - debugMux := newDebugMux() - go runDebugServer(debugMux, "0.0.0.0:8999") - - mode, err := strconv.Atoi(os.Args[1]) - if err != nil { - log.Fatalf("%q: %v", os.Args[1], err) - } - - traf := NewTrafficGen(nil) - - // Sample test results below are using GOMAXPROCS=2 (for some - // tests, including wireguard-go, higher GOMAXPROCS goes slower) - // on apenwarr's old Linux box: - // Intel(R) Core(TM) i7-4785T CPU @ 2.20GHz - // My 2019 Mac Mini is about 20% faster on most tests. - - switch mode { - // tx=8786325 rx=8786326 (0 = 0.00% loss) (70768.7 Mbits/sec) - case 1: - setupTrivialNoAllocTest(logf, traf) - - // tx=6476293 rx=6476293 (0 = 0.00% loss) (52249.7 Mbits/sec) - case 2: - setupTrivialTest(logf, traf) - - // tx=1957974 rx=1958379 (0 = 0.00% loss) (15939.8 Mbits/sec) - case 11: - setupBlockingChannelTest(logf, traf) - - // tx=728621 rx=701825 (26620 = 3.65% loss) (5525.2 Mbits/sec) - // (much faster on macOS??) - case 12: - setupNonblockingChannelTest(logf, traf) - - // tx=1024260 rx=941098 (83334 = 8.14% loss) (7516.6 Mbits/sec) - // (much faster on macOS??) - case 13: - setupDoubleChannelTest(logf, traf) - - // tx=265468 rx=263189 (2279 = 0.86% loss) (2162.0 Mbits/sec) - case 21: - setupUDPTest(logf, traf) - - // tx=1493580 rx=1493580 (0 = 0.00% loss) (12210.4 Mbits/sec) - case 31: - setupBatchTCPTest(logf, traf) - - // tx=134236 rx=133166 (1070 = 0.80% loss) (1088.9 Mbits/sec) - case 101: - setupWGTest(nil, logf, traf, Addr1, Addr2) - - default: - log.Fatalf("provide a valid test number (0..n)") - } - - logf("initialized ok.") - traf.Start(Addr1.Addr(), Addr2.Addr(), PayloadSize+ICMPMinSize, 0) - - var cur, prev Snapshot - var pps int64 - i := 0 - for { - i += 1 - time.Sleep(10 * time.Millisecond) - - if (i % 100) == 0 { - prev = cur - cur = traf.Snap() - d := cur.Sub(prev) - - if prev.WhenNsec == 0 { - logf("tx=%-6d rx=%-6d", d.TxPackets, d.RxPackets) - } else { - logf("%v @%7d pkt/s", d, pps) - } - } - - pps = traf.Adjust() - } -} - -func newDebugMux() *http.ServeMux { - mux := http.NewServeMux() - mux.HandleFunc("/debug/pprof/", pprof.Index) - mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) - mux.HandleFunc("/debug/pprof/profile", pprof.Profile) - mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) - mux.HandleFunc("/debug/pprof/trace", pprof.Trace) - return mux -} - -func runDebugServer(mux *http.ServeMux, addr string) { - srv := &http.Server{ - Addr: addr, - Handler: mux, - } - if err := srv.ListenAndServe(); err != nil { - log.Fatal(err) - } -} - -// The absolute minimal test of the traffic generator: have it fill -// a packet buffer, then absorb it again. Zero packet loss. -func setupTrivialNoAllocTest(logf logger.Logf, traf *TrafficGen) { - go func() { - b := make([]byte, 1600) - for { - n := traf.Generate(b, 16) - if n == 0 { - break - } - traf.GotPacket(b[0:n+16], 16) - } - }() -} - -// Almost the same, but this time allocate a fresh buffer each time -// through the loop. Still zero packet loss. Runs about 2/3 as fast for me. -func setupTrivialTest(logf logger.Logf, traf *TrafficGen) { - go func() { - for { - b := make([]byte, 1600) - n := traf.Generate(b, 16) - if n == 0 { - break - } - traf.GotPacket(b[0:n+16], 16) - } - }() -} - -// Pass packets through a blocking channel between sender and receiver. -// Still zero packet loss since the sender stops when the channel is full. -// Max speed depends on channel length (I'm not sure why). -func setupBlockingChannelTest(logf logger.Logf, traf *TrafficGen) { - ch := make(chan []byte, 1000) - - go func() { - // transmitter - for { - b := make([]byte, 1600) - n := traf.Generate(b, 16) - if n == 0 { - close(ch) - break - } - ch <- b[0 : n+16] - } - }() - - go func() { - // receiver - for b := range ch { - traf.GotPacket(b, 16) - } - }() -} - -// Same as setupBlockingChannelTest, but now we drop packets whenever the -// channel is full. Max speed is about the same as the above test, but -// now with nonzero packet loss. -func setupNonblockingChannelTest(logf logger.Logf, traf *TrafficGen) { - ch := make(chan []byte, 1000) - - go func() { - // transmitter - for { - b := make([]byte, 1600) - n := traf.Generate(b, 16) - if n == 0 { - close(ch) - break - } - select { - case ch <- b[0 : n+16]: - default: - } - } - }() - - go func() { - // receiver - for b := range ch { - traf.GotPacket(b, 16) - } - }() -} - -// Same as above, but at an intermediate blocking channel and goroutine -// to make things a little more like wireguard-go. Roughly 20% slower than -// the single-channel version. -func setupDoubleChannelTest(logf logger.Logf, traf *TrafficGen) { - ch := make(chan []byte, 1000) - ch2 := make(chan []byte, 1000) - - go func() { - // transmitter - for { - b := make([]byte, 1600) - n := traf.Generate(b, 16) - if n == 0 { - close(ch) - break - } - select { - case ch <- b[0 : n+16]: - default: - } - } - }() - - go func() { - // intermediary - for b := range ch { - ch2 <- b - } - close(ch2) - }() - - go func() { - // receiver - for b := range ch2 { - traf.GotPacket(b, 16) - } - }() -} - -// Instead of a channel, pass packets through a UDP socket. -func setupUDPTest(logf logger.Logf, traf *TrafficGen) { - la, err := net.ResolveUDPAddr("udp", ":0") - if err != nil { - log.Fatalf("resolve: %v", err) - } - - s1, err := net.ListenUDP("udp", la) - if err != nil { - log.Fatalf("listen1: %v", err) - } - s2, err := net.ListenUDP("udp", la) - if err != nil { - log.Fatalf("listen2: %v", err) - } - - a2 := s2.LocalAddr() - - // On macOS (but not Linux), you can't transmit to 0.0.0.0:port, - // which is what returns from .LocalAddr() above. We have to - // force it to localhost instead. - a2.(*net.UDPAddr).IP = net.ParseIP("127.0.0.1") - - s1.SetWriteBuffer(1024 * 1024) - s2.SetReadBuffer(1024 * 1024) - - go func() { - // transmitter - b := make([]byte, 1600) - for { - n := traf.Generate(b, 16) - if n == 0 { - break - } - s1.WriteTo(b[16:n+16], a2) - } - }() - - go func() { - // receiver - b := make([]byte, 1600) - for traf.Running() { - // Use ReadFrom instead of Read, to be more like - // how wireguard-go does it, even though we're not - // going to actually look at the address. - n, _, err := s2.ReadFrom(b) - if err != nil { - log.Fatalf("s2.Read: %v", err) - } - traf.GotPacket(b[:n], 0) - } - }() -} - -// Instead of a channel, pass packets through a TCP socket. -// TCP is a single stream, so we can amortize one syscall across -// multiple packets. 10x amortization seems to make it go ~10x faster, -// as expected, getting us close to the speed of the channel tests above. -// There's also zero packet loss. -func setupBatchTCPTest(logf logger.Logf, traf *TrafficGen) { - sl, err := net.Listen("tcp", ":0") - if err != nil { - log.Fatalf("listen: %v", err) - } - - var slCloseOnce sync.Once - slClose := func() { - slCloseOnce.Do(func() { - sl.Close() - }) - } - - s1, err := net.Dial("tcp", sl.Addr().String()) - if err != nil { - log.Fatalf("dial: %v", err) - } - - s2, err := sl.Accept() - if err != nil { - log.Fatalf("accept: %v", err) - } - - s1.(*net.TCPConn).SetWriteBuffer(1024 * 1024) - s2.(*net.TCPConn).SetReadBuffer(1024 * 1024) - - ch := make(chan int) - - go func() { - // transmitter - defer slClose() - defer s1.Close() - - bs1 := bufio.NewWriterSize(s1, 1024*1024) - - b := make([]byte, 1600) - i := 0 - for { - i += 1 - n := traf.Generate(b, 16) - if n == 0 { - break - } - if i == 1 { - ch <- n - } - bs1.Write(b[16 : n+16]) - - // TODO: this is a pretty half-baked batching - // function, which we'd never want to employ in - // a real-life program. - // - // In real life, we'd probably want to flush - // immediately when there are no more packets to - // generate, and queue up only if we fall behind. - // - // In our case however, we just want to see the - // technical benefits of batching 10 syscalls - // into 1, so a fixed ratio makes more sense. - if (i % 10) == 0 { - bs1.Flush() - } - } - }() - - go func() { - // receiver - defer slClose() - defer s2.Close() - - bs2 := bufio.NewReaderSize(s2, 1024*1024) - - // Find out the packet size (we happen to know they're - // all the same size) - packetSize := <-ch - - b := make([]byte, packetSize) - for traf.Running() { - // TODO: can't use ReadFrom() here, which is - // unfair compared to UDP. (ReadFrom for UDP - // apparently allocates memory per packet, which - // this test does not.) - n, err := io.ReadFull(bs2, b) - if err != nil { - log.Fatalf("s2.Read: %v", err) - } - traf.GotPacket(b[:n], 0) - } - }() -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Create two wgengine instances and pass data through them, measuring +// throughput, latency, and packet loss. +package main + +import ( + "bufio" + "io" + "log" + "net" + "net/http" + "net/http/pprof" + "net/netip" + "os" + "strconv" + "sync" + "time" + + "tailscale.com/types/logger" +) + +const PayloadSize = 1000 +const ICMPMinSize = 24 + +var Addr1 = netip.MustParsePrefix("100.64.1.1/32") +var Addr2 = netip.MustParsePrefix("100.64.1.2/32") + +func main() { + var logf logger.Logf = log.Printf + log.SetFlags(0) + + debugMux := newDebugMux() + go runDebugServer(debugMux, "0.0.0.0:8999") + + mode, err := strconv.Atoi(os.Args[1]) + if err != nil { + log.Fatalf("%q: %v", os.Args[1], err) + } + + traf := NewTrafficGen(nil) + + // Sample test results below are using GOMAXPROCS=2 (for some + // tests, including wireguard-go, higher GOMAXPROCS goes slower) + // on apenwarr's old Linux box: + // Intel(R) Core(TM) i7-4785T CPU @ 2.20GHz + // My 2019 Mac Mini is about 20% faster on most tests. + + switch mode { + // tx=8786325 rx=8786326 (0 = 0.00% loss) (70768.7 Mbits/sec) + case 1: + setupTrivialNoAllocTest(logf, traf) + + // tx=6476293 rx=6476293 (0 = 0.00% loss) (52249.7 Mbits/sec) + case 2: + setupTrivialTest(logf, traf) + + // tx=1957974 rx=1958379 (0 = 0.00% loss) (15939.8 Mbits/sec) + case 11: + setupBlockingChannelTest(logf, traf) + + // tx=728621 rx=701825 (26620 = 3.65% loss) (5525.2 Mbits/sec) + // (much faster on macOS??) + case 12: + setupNonblockingChannelTest(logf, traf) + + // tx=1024260 rx=941098 (83334 = 8.14% loss) (7516.6 Mbits/sec) + // (much faster on macOS??) + case 13: + setupDoubleChannelTest(logf, traf) + + // tx=265468 rx=263189 (2279 = 0.86% loss) (2162.0 Mbits/sec) + case 21: + setupUDPTest(logf, traf) + + // tx=1493580 rx=1493580 (0 = 0.00% loss) (12210.4 Mbits/sec) + case 31: + setupBatchTCPTest(logf, traf) + + // tx=134236 rx=133166 (1070 = 0.80% loss) (1088.9 Mbits/sec) + case 101: + setupWGTest(nil, logf, traf, Addr1, Addr2) + + default: + log.Fatalf("provide a valid test number (0..n)") + } + + logf("initialized ok.") + traf.Start(Addr1.Addr(), Addr2.Addr(), PayloadSize+ICMPMinSize, 0) + + var cur, prev Snapshot + var pps int64 + i := 0 + for { + i += 1 + time.Sleep(10 * time.Millisecond) + + if (i % 100) == 0 { + prev = cur + cur = traf.Snap() + d := cur.Sub(prev) + + if prev.WhenNsec == 0 { + logf("tx=%-6d rx=%-6d", d.TxPackets, d.RxPackets) + } else { + logf("%v @%7d pkt/s", d, pps) + } + } + + pps = traf.Adjust() + } +} + +func newDebugMux() *http.ServeMux { + mux := http.NewServeMux() + mux.HandleFunc("/debug/pprof/", pprof.Index) + mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) + mux.HandleFunc("/debug/pprof/profile", pprof.Profile) + mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) + mux.HandleFunc("/debug/pprof/trace", pprof.Trace) + return mux +} + +func runDebugServer(mux *http.ServeMux, addr string) { + srv := &http.Server{ + Addr: addr, + Handler: mux, + } + if err := srv.ListenAndServe(); err != nil { + log.Fatal(err) + } +} + +// The absolute minimal test of the traffic generator: have it fill +// a packet buffer, then absorb it again. Zero packet loss. +func setupTrivialNoAllocTest(logf logger.Logf, traf *TrafficGen) { + go func() { + b := make([]byte, 1600) + for { + n := traf.Generate(b, 16) + if n == 0 { + break + } + traf.GotPacket(b[0:n+16], 16) + } + }() +} + +// Almost the same, but this time allocate a fresh buffer each time +// through the loop. Still zero packet loss. Runs about 2/3 as fast for me. +func setupTrivialTest(logf logger.Logf, traf *TrafficGen) { + go func() { + for { + b := make([]byte, 1600) + n := traf.Generate(b, 16) + if n == 0 { + break + } + traf.GotPacket(b[0:n+16], 16) + } + }() +} + +// Pass packets through a blocking channel between sender and receiver. +// Still zero packet loss since the sender stops when the channel is full. +// Max speed depends on channel length (I'm not sure why). +func setupBlockingChannelTest(logf logger.Logf, traf *TrafficGen) { + ch := make(chan []byte, 1000) + + go func() { + // transmitter + for { + b := make([]byte, 1600) + n := traf.Generate(b, 16) + if n == 0 { + close(ch) + break + } + ch <- b[0 : n+16] + } + }() + + go func() { + // receiver + for b := range ch { + traf.GotPacket(b, 16) + } + }() +} + +// Same as setupBlockingChannelTest, but now we drop packets whenever the +// channel is full. Max speed is about the same as the above test, but +// now with nonzero packet loss. +func setupNonblockingChannelTest(logf logger.Logf, traf *TrafficGen) { + ch := make(chan []byte, 1000) + + go func() { + // transmitter + for { + b := make([]byte, 1600) + n := traf.Generate(b, 16) + if n == 0 { + close(ch) + break + } + select { + case ch <- b[0 : n+16]: + default: + } + } + }() + + go func() { + // receiver + for b := range ch { + traf.GotPacket(b, 16) + } + }() +} + +// Same as above, but at an intermediate blocking channel and goroutine +// to make things a little more like wireguard-go. Roughly 20% slower than +// the single-channel version. +func setupDoubleChannelTest(logf logger.Logf, traf *TrafficGen) { + ch := make(chan []byte, 1000) + ch2 := make(chan []byte, 1000) + + go func() { + // transmitter + for { + b := make([]byte, 1600) + n := traf.Generate(b, 16) + if n == 0 { + close(ch) + break + } + select { + case ch <- b[0 : n+16]: + default: + } + } + }() + + go func() { + // intermediary + for b := range ch { + ch2 <- b + } + close(ch2) + }() + + go func() { + // receiver + for b := range ch2 { + traf.GotPacket(b, 16) + } + }() +} + +// Instead of a channel, pass packets through a UDP socket. +func setupUDPTest(logf logger.Logf, traf *TrafficGen) { + la, err := net.ResolveUDPAddr("udp", ":0") + if err != nil { + log.Fatalf("resolve: %v", err) + } + + s1, err := net.ListenUDP("udp", la) + if err != nil { + log.Fatalf("listen1: %v", err) + } + s2, err := net.ListenUDP("udp", la) + if err != nil { + log.Fatalf("listen2: %v", err) + } + + a2 := s2.LocalAddr() + + // On macOS (but not Linux), you can't transmit to 0.0.0.0:port, + // which is what returns from .LocalAddr() above. We have to + // force it to localhost instead. + a2.(*net.UDPAddr).IP = net.ParseIP("127.0.0.1") + + s1.SetWriteBuffer(1024 * 1024) + s2.SetReadBuffer(1024 * 1024) + + go func() { + // transmitter + b := make([]byte, 1600) + for { + n := traf.Generate(b, 16) + if n == 0 { + break + } + s1.WriteTo(b[16:n+16], a2) + } + }() + + go func() { + // receiver + b := make([]byte, 1600) + for traf.Running() { + // Use ReadFrom instead of Read, to be more like + // how wireguard-go does it, even though we're not + // going to actually look at the address. + n, _, err := s2.ReadFrom(b) + if err != nil { + log.Fatalf("s2.Read: %v", err) + } + traf.GotPacket(b[:n], 0) + } + }() +} + +// Instead of a channel, pass packets through a TCP socket. +// TCP is a single stream, so we can amortize one syscall across +// multiple packets. 10x amortization seems to make it go ~10x faster, +// as expected, getting us close to the speed of the channel tests above. +// There's also zero packet loss. +func setupBatchTCPTest(logf logger.Logf, traf *TrafficGen) { + sl, err := net.Listen("tcp", ":0") + if err != nil { + log.Fatalf("listen: %v", err) + } + + var slCloseOnce sync.Once + slClose := func() { + slCloseOnce.Do(func() { + sl.Close() + }) + } + + s1, err := net.Dial("tcp", sl.Addr().String()) + if err != nil { + log.Fatalf("dial: %v", err) + } + + s2, err := sl.Accept() + if err != nil { + log.Fatalf("accept: %v", err) + } + + s1.(*net.TCPConn).SetWriteBuffer(1024 * 1024) + s2.(*net.TCPConn).SetReadBuffer(1024 * 1024) + + ch := make(chan int) + + go func() { + // transmitter + defer slClose() + defer s1.Close() + + bs1 := bufio.NewWriterSize(s1, 1024*1024) + + b := make([]byte, 1600) + i := 0 + for { + i += 1 + n := traf.Generate(b, 16) + if n == 0 { + break + } + if i == 1 { + ch <- n + } + bs1.Write(b[16 : n+16]) + + // TODO: this is a pretty half-baked batching + // function, which we'd never want to employ in + // a real-life program. + // + // In real life, we'd probably want to flush + // immediately when there are no more packets to + // generate, and queue up only if we fall behind. + // + // In our case however, we just want to see the + // technical benefits of batching 10 syscalls + // into 1, so a fixed ratio makes more sense. + if (i % 10) == 0 { + bs1.Flush() + } + } + }() + + go func() { + // receiver + defer slClose() + defer s2.Close() + + bs2 := bufio.NewReaderSize(s2, 1024*1024) + + // Find out the packet size (we happen to know they're + // all the same size) + packetSize := <-ch + + b := make([]byte, packetSize) + for traf.Running() { + // TODO: can't use ReadFrom() here, which is + // unfair compared to UDP. (ReadFrom for UDP + // apparently allocates memory per packet, which + // this test does not.) + n, err := io.ReadFull(bs2, b) + if err != nil { + log.Fatalf("s2.Read: %v", err) + } + traf.GotPacket(b[:n], 0) + } + }() +} diff --git a/wgengine/bench/bench_test.go b/wgengine/bench/bench_test.go index 42571d055..4fae86c05 100644 --- a/wgengine/bench/bench_test.go +++ b/wgengine/bench/bench_test.go @@ -1,108 +1,108 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Create two wgengine instances and pass data through them, measuring -// throughput, latency, and packet loss. -package main - -import ( - "fmt" - "testing" - "time" - - "tailscale.com/types/logger" -) - -func BenchmarkTrivialNoAlloc(b *testing.B) { - run(b, setupTrivialNoAllocTest) -} -func BenchmarkTrivial(b *testing.B) { - run(b, setupTrivialTest) -} - -func BenchmarkBlockingChannel(b *testing.B) { - run(b, setupBlockingChannelTest) -} - -func BenchmarkNonblockingChannel(b *testing.B) { - run(b, setupNonblockingChannelTest) -} - -func BenchmarkDoubleChannel(b *testing.B) { - run(b, setupDoubleChannelTest) -} - -func BenchmarkUDP(b *testing.B) { - run(b, setupUDPTest) -} - -func BenchmarkBatchTCP(b *testing.B) { - run(b, setupBatchTCPTest) -} - -func BenchmarkWireGuardTest(b *testing.B) { - b.Skip("https://github.com/tailscale/tailscale/issues/2716") - run(b, func(logf logger.Logf, traf *TrafficGen) { - setupWGTest(b, logf, traf, Addr1, Addr2) - }) -} - -type SetupFunc func(logger.Logf, *TrafficGen) - -func run(b *testing.B, setup SetupFunc) { - sizes := []int{ - ICMPMinSize + 8, - ICMPMinSize + 100, - ICMPMinSize + 1000, - } - - for _, size := range sizes { - b.Run(fmt.Sprintf("%d", size), func(b *testing.B) { - runOnce(b, setup, size) - }) - } -} - -func runOnce(b *testing.B, setup SetupFunc, payload int) { - b.StopTimer() - b.ReportAllocs() - - var logf logger.Logf = b.Logf - if !testing.Verbose() { - logf = logger.Discard - } - - traf := NewTrafficGen(b.StartTimer) - setup(logf, traf) - - logf("initialized. (n=%v)", b.N) - b.SetBytes(int64(payload)) - - traf.Start(Addr1.Addr(), Addr2.Addr(), payload, int64(b.N)) - - var cur, prev Snapshot - var pps int64 - i := 0 - for traf.Running() { - i += 1 - time.Sleep(10 * time.Millisecond) - - if (i % 100) == 0 { - prev = cur - cur = traf.Snap() - d := cur.Sub(prev) - - if prev.WhenNsec != 0 { - logf("%v @%7d pkt/sec", d, pps) - } - } - - pps = traf.Adjust() - } - - cur = traf.Snap() - d := cur.Sub(prev) - loss := float64(d.LostPackets) / float64(d.RxPackets) - - b.ReportMetric(loss*100, "%lost") -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Create two wgengine instances and pass data through them, measuring +// throughput, latency, and packet loss. +package main + +import ( + "fmt" + "testing" + "time" + + "tailscale.com/types/logger" +) + +func BenchmarkTrivialNoAlloc(b *testing.B) { + run(b, setupTrivialNoAllocTest) +} +func BenchmarkTrivial(b *testing.B) { + run(b, setupTrivialTest) +} + +func BenchmarkBlockingChannel(b *testing.B) { + run(b, setupBlockingChannelTest) +} + +func BenchmarkNonblockingChannel(b *testing.B) { + run(b, setupNonblockingChannelTest) +} + +func BenchmarkDoubleChannel(b *testing.B) { + run(b, setupDoubleChannelTest) +} + +func BenchmarkUDP(b *testing.B) { + run(b, setupUDPTest) +} + +func BenchmarkBatchTCP(b *testing.B) { + run(b, setupBatchTCPTest) +} + +func BenchmarkWireGuardTest(b *testing.B) { + b.Skip("https://github.com/tailscale/tailscale/issues/2716") + run(b, func(logf logger.Logf, traf *TrafficGen) { + setupWGTest(b, logf, traf, Addr1, Addr2) + }) +} + +type SetupFunc func(logger.Logf, *TrafficGen) + +func run(b *testing.B, setup SetupFunc) { + sizes := []int{ + ICMPMinSize + 8, + ICMPMinSize + 100, + ICMPMinSize + 1000, + } + + for _, size := range sizes { + b.Run(fmt.Sprintf("%d", size), func(b *testing.B) { + runOnce(b, setup, size) + }) + } +} + +func runOnce(b *testing.B, setup SetupFunc, payload int) { + b.StopTimer() + b.ReportAllocs() + + var logf logger.Logf = b.Logf + if !testing.Verbose() { + logf = logger.Discard + } + + traf := NewTrafficGen(b.StartTimer) + setup(logf, traf) + + logf("initialized. (n=%v)", b.N) + b.SetBytes(int64(payload)) + + traf.Start(Addr1.Addr(), Addr2.Addr(), payload, int64(b.N)) + + var cur, prev Snapshot + var pps int64 + i := 0 + for traf.Running() { + i += 1 + time.Sleep(10 * time.Millisecond) + + if (i % 100) == 0 { + prev = cur + cur = traf.Snap() + d := cur.Sub(prev) + + if prev.WhenNsec != 0 { + logf("%v @%7d pkt/sec", d, pps) + } + } + + pps = traf.Adjust() + } + + cur = traf.Snap() + d := cur.Sub(prev) + loss := float64(d.LostPackets) / float64(d.RxPackets) + + b.ReportMetric(loss*100, "%lost") +} diff --git a/wgengine/bench/trafficgen.go b/wgengine/bench/trafficgen.go index 9de3c2e6b..ce79c616f 100644 --- a/wgengine/bench/trafficgen.go +++ b/wgengine/bench/trafficgen.go @@ -1,259 +1,259 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package main - -import ( - "encoding/binary" - "fmt" - "log" - "net/netip" - "sync" - "time" - - "tailscale.com/net/packet" - "tailscale.com/types/ipproto" -) - -type Snapshot struct { - WhenNsec int64 // current time - timeAcc int64 // accumulated time (+NSecPerTx per transmit) - - LastSeqTx int64 // last sequence number sent - LastSeqRx int64 // last sequence number received - TotalLost int64 // packets out-of-order or lost so far - TotalOOO int64 // packets out-of-order so far - TotalBytesRx int64 // total bytes received so far -} - -type Delta struct { - DurationNsec int64 - TxPackets int64 - RxPackets int64 - LostPackets int64 - OOOPackets int64 - Bytes int64 -} - -func (b Snapshot) Sub(a Snapshot) Delta { - return Delta{ - DurationNsec: b.WhenNsec - a.WhenNsec, - TxPackets: b.LastSeqTx - a.LastSeqTx, - RxPackets: (b.LastSeqRx - a.LastSeqRx) - - (b.TotalLost - a.TotalLost) + - (b.TotalOOO - a.TotalOOO), - LostPackets: b.TotalLost - a.TotalLost, - OOOPackets: b.TotalOOO - a.TotalOOO, - Bytes: b.TotalBytesRx - a.TotalBytesRx, - } -} - -func (d Delta) String() string { - return fmt.Sprintf("tx=%-6d rx=%-4d (%6d = %.1f%% loss) (%d OOO) (%4.1f Mbit/s)", - d.TxPackets, d.RxPackets, d.LostPackets, - float64(d.LostPackets)*100/float64(d.TxPackets), - d.OOOPackets, - float64(d.Bytes)*8*1e9/float64(d.DurationNsec)/1e6) -} - -type TrafficGen struct { - mu sync.Mutex - cur, prev Snapshot // snapshots used for rate control - buf []byte // pre-generated packet buffer - done bool // true if the test has completed - - onFirstPacket func() // function to call on first received packet - - // maxPackets is the max packets to receive (not send) before - // ending the test. If it's zero, the test runs forever. - maxPackets int64 - - // nsPerPacket is the target average nanoseconds between packets. - // It's initially zero, which means transmit as fast as the - // caller wants to go. - nsPerPacket int64 - - // ppsHistory is the observed packets-per-second from recent - // samples. - ppsHistory [5]int64 -} - -// NewTrafficGen creates a new, initially locked, TrafficGen. -// Until Start() is called, Generate() will block forever. -func NewTrafficGen(onFirstPacket func()) *TrafficGen { - t := TrafficGen{ - onFirstPacket: onFirstPacket, - } - - // initially locked, until first Start() - t.mu.Lock() - - return &t -} - -// Start starts the traffic generator. It assumes mu is already locked, -// and unlocks it. -func (t *TrafficGen) Start(src, dst netip.Addr, bytesPerPacket int, maxPackets int64) { - h12 := packet.ICMP4Header{ - IP4Header: packet.IP4Header{ - IPProto: ipproto.ICMPv4, - IPID: 0, - Src: src, - Dst: dst, - }, - Type: packet.ICMP4EchoRequest, - Code: packet.ICMP4NoCode, - } - - // ensure there's room for ICMP header plus sequence number - if bytesPerPacket < ICMPMinSize+8 { - log.Fatalf("bytesPerPacket must be > 24+8") - } - - t.maxPackets = maxPackets - - payload := make([]byte, bytesPerPacket-ICMPMinSize) - t.buf = packet.Generate(h12, payload) - - t.mu.Unlock() -} - -func (t *TrafficGen) Snap() Snapshot { - t.mu.Lock() - defer t.mu.Unlock() - - t.cur.WhenNsec = time.Now().UnixNano() - return t.cur -} - -func (t *TrafficGen) Running() bool { - t.mu.Lock() - defer t.mu.Unlock() - - return !t.done -} - -// Generate produces the next packet in the sequence. It sleeps if -// it's too soon for the next packet to be sent. -// -// The generated packet is placed into buf at offset ofs, for compatibility -// with the wireguard-go conventions. -// -// The return value is the number of bytes generated in the packet, or 0 -// if the test has finished running. -func (t *TrafficGen) Generate(b []byte, ofs int) int { - t.mu.Lock() - - now := time.Now().UnixNano() - if t.nsPerPacket == 0 || t.cur.timeAcc == 0 { - t.cur.timeAcc = now - 1 - } - if t.cur.timeAcc >= now { - // too soon - t.mu.Unlock() - time.Sleep(time.Duration(t.cur.timeAcc-now) * time.Nanosecond) - t.mu.Lock() - - now = t.cur.timeAcc - } - if t.done { - t.mu.Unlock() - return 0 - } - - t.cur.timeAcc += t.nsPerPacket - t.cur.LastSeqTx += 1 - t.cur.WhenNsec = now - seq := t.cur.LastSeqTx - - t.mu.Unlock() - - copy(b[ofs:], t.buf) - binary.BigEndian.PutUint64( - b[ofs+ICMPMinSize:ofs+ICMPMinSize+8], - uint64(seq)) - - return len(t.buf) -} - -// GotPacket processes a packet that came back on the receive side. -func (t *TrafficGen) GotPacket(b []byte, ofs int) { - t.mu.Lock() - defer t.mu.Unlock() - - s := &t.cur - seq := int64(binary.BigEndian.Uint64( - b[ofs+ICMPMinSize : ofs+ICMPMinSize+8])) - if seq > s.LastSeqRx { - if s.LastSeqRx > 0 { - // only count lost packets after the very first - // successful one. - s.TotalLost += seq - s.LastSeqRx - 1 - } - s.LastSeqRx = seq - } else { - s.TotalOOO += 1 - } - - // +1 packet since we only start counting after the first one - if t.maxPackets > 0 && s.LastSeqRx >= t.maxPackets+1 { - t.done = true - } - s.TotalBytesRx += int64(len(b) - ofs) - - f := t.onFirstPacket - t.onFirstPacket = nil - if f != nil { - f() - } -} - -// Adjust tunes the transmit rate based on the received packets. -// The goal is to converge on the fastest transmit rate that still has -// minimal packet loss. Returns the new target rate in packets/sec. -// -// We need to play this guessing game in order to balance out tx and rx -// rates when there's a lossy network between them. Otherwise we can end -// up using 99% of the CPU to blast out transmitted packets and leaving only -// 1% to receive them, leading to a misleading throughput calculation. -// -// Call this function multiple times per second. -func (t *TrafficGen) Adjust() (pps int64) { - t.mu.Lock() - defer t.mu.Unlock() - - d := t.cur.Sub(t.prev) - - // don't adjust rate until the first full period *after* receiving - // the first packet. This skips any handshake time in the underlying - // transport. - if t.prev.LastSeqRx == 0 || d.DurationNsec == 0 { - t.prev = t.cur - return 0 // no estimate yet, continue at max speed - } - - pps = int64(d.RxPackets) * 1e9 / int64(d.DurationNsec) - - // We use a rate selection algorithm based loosely on TCP BBR. - // Basically, we set the transmit rate to be a bit higher than - // the best observed transmit rate in the last several time - // periods. This guarantees some packet loss, but should converge - // quickly on a rate near the sustainable maximum. - bestPPS := pps - for _, p := range t.ppsHistory { - if p > bestPPS { - bestPPS = p - } - } - if pps > 0 && t.prev.WhenNsec > 0 { - copy(t.ppsHistory[1:], t.ppsHistory[0:len(t.ppsHistory)-1]) - t.ppsHistory[0] = pps - } - if bestPPS > 0 { - pps = bestPPS * 103 / 100 - t.nsPerPacket = int64(1e9 / pps) - } - t.prev = t.cur - - return pps -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "encoding/binary" + "fmt" + "log" + "net/netip" + "sync" + "time" + + "tailscale.com/net/packet" + "tailscale.com/types/ipproto" +) + +type Snapshot struct { + WhenNsec int64 // current time + timeAcc int64 // accumulated time (+NSecPerTx per transmit) + + LastSeqTx int64 // last sequence number sent + LastSeqRx int64 // last sequence number received + TotalLost int64 // packets out-of-order or lost so far + TotalOOO int64 // packets out-of-order so far + TotalBytesRx int64 // total bytes received so far +} + +type Delta struct { + DurationNsec int64 + TxPackets int64 + RxPackets int64 + LostPackets int64 + OOOPackets int64 + Bytes int64 +} + +func (b Snapshot) Sub(a Snapshot) Delta { + return Delta{ + DurationNsec: b.WhenNsec - a.WhenNsec, + TxPackets: b.LastSeqTx - a.LastSeqTx, + RxPackets: (b.LastSeqRx - a.LastSeqRx) - + (b.TotalLost - a.TotalLost) + + (b.TotalOOO - a.TotalOOO), + LostPackets: b.TotalLost - a.TotalLost, + OOOPackets: b.TotalOOO - a.TotalOOO, + Bytes: b.TotalBytesRx - a.TotalBytesRx, + } +} + +func (d Delta) String() string { + return fmt.Sprintf("tx=%-6d rx=%-4d (%6d = %.1f%% loss) (%d OOO) (%4.1f Mbit/s)", + d.TxPackets, d.RxPackets, d.LostPackets, + float64(d.LostPackets)*100/float64(d.TxPackets), + d.OOOPackets, + float64(d.Bytes)*8*1e9/float64(d.DurationNsec)/1e6) +} + +type TrafficGen struct { + mu sync.Mutex + cur, prev Snapshot // snapshots used for rate control + buf []byte // pre-generated packet buffer + done bool // true if the test has completed + + onFirstPacket func() // function to call on first received packet + + // maxPackets is the max packets to receive (not send) before + // ending the test. If it's zero, the test runs forever. + maxPackets int64 + + // nsPerPacket is the target average nanoseconds between packets. + // It's initially zero, which means transmit as fast as the + // caller wants to go. + nsPerPacket int64 + + // ppsHistory is the observed packets-per-second from recent + // samples. + ppsHistory [5]int64 +} + +// NewTrafficGen creates a new, initially locked, TrafficGen. +// Until Start() is called, Generate() will block forever. +func NewTrafficGen(onFirstPacket func()) *TrafficGen { + t := TrafficGen{ + onFirstPacket: onFirstPacket, + } + + // initially locked, until first Start() + t.mu.Lock() + + return &t +} + +// Start starts the traffic generator. It assumes mu is already locked, +// and unlocks it. +func (t *TrafficGen) Start(src, dst netip.Addr, bytesPerPacket int, maxPackets int64) { + h12 := packet.ICMP4Header{ + IP4Header: packet.IP4Header{ + IPProto: ipproto.ICMPv4, + IPID: 0, + Src: src, + Dst: dst, + }, + Type: packet.ICMP4EchoRequest, + Code: packet.ICMP4NoCode, + } + + // ensure there's room for ICMP header plus sequence number + if bytesPerPacket < ICMPMinSize+8 { + log.Fatalf("bytesPerPacket must be > 24+8") + } + + t.maxPackets = maxPackets + + payload := make([]byte, bytesPerPacket-ICMPMinSize) + t.buf = packet.Generate(h12, payload) + + t.mu.Unlock() +} + +func (t *TrafficGen) Snap() Snapshot { + t.mu.Lock() + defer t.mu.Unlock() + + t.cur.WhenNsec = time.Now().UnixNano() + return t.cur +} + +func (t *TrafficGen) Running() bool { + t.mu.Lock() + defer t.mu.Unlock() + + return !t.done +} + +// Generate produces the next packet in the sequence. It sleeps if +// it's too soon for the next packet to be sent. +// +// The generated packet is placed into buf at offset ofs, for compatibility +// with the wireguard-go conventions. +// +// The return value is the number of bytes generated in the packet, or 0 +// if the test has finished running. +func (t *TrafficGen) Generate(b []byte, ofs int) int { + t.mu.Lock() + + now := time.Now().UnixNano() + if t.nsPerPacket == 0 || t.cur.timeAcc == 0 { + t.cur.timeAcc = now - 1 + } + if t.cur.timeAcc >= now { + // too soon + t.mu.Unlock() + time.Sleep(time.Duration(t.cur.timeAcc-now) * time.Nanosecond) + t.mu.Lock() + + now = t.cur.timeAcc + } + if t.done { + t.mu.Unlock() + return 0 + } + + t.cur.timeAcc += t.nsPerPacket + t.cur.LastSeqTx += 1 + t.cur.WhenNsec = now + seq := t.cur.LastSeqTx + + t.mu.Unlock() + + copy(b[ofs:], t.buf) + binary.BigEndian.PutUint64( + b[ofs+ICMPMinSize:ofs+ICMPMinSize+8], + uint64(seq)) + + return len(t.buf) +} + +// GotPacket processes a packet that came back on the receive side. +func (t *TrafficGen) GotPacket(b []byte, ofs int) { + t.mu.Lock() + defer t.mu.Unlock() + + s := &t.cur + seq := int64(binary.BigEndian.Uint64( + b[ofs+ICMPMinSize : ofs+ICMPMinSize+8])) + if seq > s.LastSeqRx { + if s.LastSeqRx > 0 { + // only count lost packets after the very first + // successful one. + s.TotalLost += seq - s.LastSeqRx - 1 + } + s.LastSeqRx = seq + } else { + s.TotalOOO += 1 + } + + // +1 packet since we only start counting after the first one + if t.maxPackets > 0 && s.LastSeqRx >= t.maxPackets+1 { + t.done = true + } + s.TotalBytesRx += int64(len(b) - ofs) + + f := t.onFirstPacket + t.onFirstPacket = nil + if f != nil { + f() + } +} + +// Adjust tunes the transmit rate based on the received packets. +// The goal is to converge on the fastest transmit rate that still has +// minimal packet loss. Returns the new target rate in packets/sec. +// +// We need to play this guessing game in order to balance out tx and rx +// rates when there's a lossy network between them. Otherwise we can end +// up using 99% of the CPU to blast out transmitted packets and leaving only +// 1% to receive them, leading to a misleading throughput calculation. +// +// Call this function multiple times per second. +func (t *TrafficGen) Adjust() (pps int64) { + t.mu.Lock() + defer t.mu.Unlock() + + d := t.cur.Sub(t.prev) + + // don't adjust rate until the first full period *after* receiving + // the first packet. This skips any handshake time in the underlying + // transport. + if t.prev.LastSeqRx == 0 || d.DurationNsec == 0 { + t.prev = t.cur + return 0 // no estimate yet, continue at max speed + } + + pps = int64(d.RxPackets) * 1e9 / int64(d.DurationNsec) + + // We use a rate selection algorithm based loosely on TCP BBR. + // Basically, we set the transmit rate to be a bit higher than + // the best observed transmit rate in the last several time + // periods. This guarantees some packet loss, but should converge + // quickly on a rate near the sustainable maximum. + bestPPS := pps + for _, p := range t.ppsHistory { + if p > bestPPS { + bestPPS = p + } + } + if pps > 0 && t.prev.WhenNsec > 0 { + copy(t.ppsHistory[1:], t.ppsHistory[0:len(t.ppsHistory)-1]) + t.ppsHistory[0] = pps + } + if bestPPS > 0 { + pps = bestPPS * 103 / 100 + t.nsPerPacket = int64(1e9 / pps) + } + t.prev = t.cur + + return pps +} diff --git a/wgengine/capture/capture.go b/wgengine/capture/capture.go index 01f79ea9f..6ea5a9549 100644 --- a/wgengine/capture/capture.go +++ b/wgengine/capture/capture.go @@ -1,238 +1,238 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package capture formats packet logging into a debug pcap stream. -package capture - -import ( - "bytes" - "context" - "encoding/binary" - "io" - "net/http" - "sync" - "time" - - _ "embed" - - "tailscale.com/net/packet" - "tailscale.com/util/set" -) - -//go:embed ts-dissector.lua -var DissectorLua string - -// Callback describes a function which is called to -// record packets when debugging packet-capture. -// Such callbacks must not take ownership of the -// provided data slice: it may only copy out of it -// within the lifetime of the function. -type Callback func(Path, time.Time, []byte, packet.CaptureMeta) - -var bufferPool = sync.Pool{ - New: func() any { - return new(bytes.Buffer) - }, -} - -const flushPeriod = 100 * time.Millisecond - -func writePcapHeader(w io.Writer) { - binary.Write(w, binary.LittleEndian, uint32(0xA1B2C3D4)) // pcap magic number - binary.Write(w, binary.LittleEndian, uint16(2)) // version major - binary.Write(w, binary.LittleEndian, uint16(4)) // version minor - binary.Write(w, binary.LittleEndian, uint32(0)) // this zone - binary.Write(w, binary.LittleEndian, uint32(0)) // zone significant figures - binary.Write(w, binary.LittleEndian, uint32(65535)) // max packet len - binary.Write(w, binary.LittleEndian, uint32(147)) // link-layer ID - USER0 -} - -func writePktHeader(w *bytes.Buffer, when time.Time, length int) { - s := when.Unix() - us := when.UnixMicro() - (s * 1000000) - - binary.Write(w, binary.LittleEndian, uint32(s)) // timestamp in seconds - binary.Write(w, binary.LittleEndian, uint32(us)) // timestamp microseconds - binary.Write(w, binary.LittleEndian, uint32(length)) // length present - binary.Write(w, binary.LittleEndian, uint32(length)) // total length -} - -// Path describes where in the data path the packet was captured. -type Path uint8 - -// Valid Path values. -const ( - // FromLocal indicates the packet was logged as it traversed the FromLocal path: - // i.e.: A packet from the local system into the TUN. - FromLocal Path = 0 - // FromPeer indicates the packet was logged upon reception from a remote peer. - FromPeer Path = 1 - // SynthesizedToLocal indicates the packet was generated from within tailscaled, - // and is being routed to the local machine's network stack. - SynthesizedToLocal Path = 2 - // SynthesizedToPeer indicates the packet was generated from within tailscaled, - // and is being routed to a remote Wireguard peer. - SynthesizedToPeer Path = 3 - - // PathDisco indicates the packet is information about a disco frame. - PathDisco Path = 254 -) - -// New creates a new capture sink. -func New() *Sink { - ctx, c := context.WithCancel(context.Background()) - return &Sink{ - ctx: ctx, - ctxCancel: c, - } -} - -// Type Sink handles callbacks with packets to be logged, -// formatting them into a pcap stream which is mirrored to -// all registered outputs. -type Sink struct { - ctx context.Context - ctxCancel context.CancelFunc - - mu sync.Mutex - outputs set.HandleSet[io.Writer] - flushTimer *time.Timer // or nil if none running -} - -// RegisterOutput connects an output to this sink, which -// will be written to with a pcap stream as packets are logged. -// A function is returned which unregisters the output when -// called. -// -// If w implements io.Closer, it will be closed upon error -// or when the sink is closed. If w implements http.Flusher, -// it will be flushed periodically. -func (s *Sink) RegisterOutput(w io.Writer) (unregister func()) { - select { - case <-s.ctx.Done(): - return func() {} - default: - } - - writePcapHeader(w) - s.mu.Lock() - hnd := s.outputs.Add(w) - s.mu.Unlock() - - return func() { - s.mu.Lock() - defer s.mu.Unlock() - delete(s.outputs, hnd) - } -} - -// NumOutputs returns the number of outputs registered with the sink. -func (s *Sink) NumOutputs() int { - s.mu.Lock() - defer s.mu.Unlock() - return len(s.outputs) -} - -// Close shuts down the sink. Future calls to LogPacket -// are ignored, and any registered output that implements -// io.Closer is closed. -func (s *Sink) Close() error { - s.ctxCancel() - s.mu.Lock() - defer s.mu.Unlock() - if s.flushTimer != nil { - s.flushTimer.Stop() - s.flushTimer = nil - } - - for _, o := range s.outputs { - if o, ok := o.(io.Closer); ok { - o.Close() - } - } - s.outputs = nil - return nil -} - -// WaitCh returns a channel which blocks until -// the sink is closed. -func (s *Sink) WaitCh() <-chan struct{} { - return s.ctx.Done() -} - -func customDataLen(meta packet.CaptureMeta) int { - length := 4 - if meta.DidSNAT { - length += meta.OriginalSrc.Addr().BitLen() / 8 - } - if meta.DidDNAT { - length += meta.OriginalDst.Addr().BitLen() / 8 - } - return length -} - -// LogPacket is called to insert a packet into the capture. -// -// This function does not take ownership of the provided data slice. -func (s *Sink) LogPacket(path Path, when time.Time, data []byte, meta packet.CaptureMeta) { - select { - case <-s.ctx.Done(): - return - default: - } - - extraLen := customDataLen(meta) - b := bufferPool.Get().(*bytes.Buffer) - b.Reset() - b.Grow(16 + extraLen + len(data)) // 16b pcap header + len(metadata) + len(payload) - defer bufferPool.Put(b) - - writePktHeader(b, when, len(data)+extraLen) - - // Custom tailscale debugging data - binary.Write(b, binary.LittleEndian, uint16(path)) - if meta.DidSNAT { - binary.Write(b, binary.LittleEndian, uint8(meta.OriginalSrc.Addr().BitLen()/8)) - b.Write(meta.OriginalSrc.Addr().AsSlice()) - } else { - binary.Write(b, binary.LittleEndian, uint8(0)) // SNAT addr len == 0 - } - if meta.DidDNAT { - binary.Write(b, binary.LittleEndian, uint8(meta.OriginalDst.Addr().BitLen()/8)) - b.Write(meta.OriginalDst.Addr().AsSlice()) - } else { - binary.Write(b, binary.LittleEndian, uint8(0)) // DNAT addr len == 0 - } - - b.Write(data) - - s.mu.Lock() - defer s.mu.Unlock() - - var hadError []set.Handle - for hnd, o := range s.outputs { - if _, err := o.Write(b.Bytes()); err != nil { - hadError = append(hadError, hnd) - continue - } - } - for _, hnd := range hadError { - if o, ok := s.outputs[hnd].(io.Closer); ok { - o.Close() - } - delete(s.outputs, hnd) - } - - if s.flushTimer == nil { - s.flushTimer = time.AfterFunc(flushPeriod, func() { - s.mu.Lock() - defer s.mu.Unlock() - for _, o := range s.outputs { - if f, ok := o.(http.Flusher); ok { - f.Flush() - } - } - s.flushTimer = nil - }) - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package capture formats packet logging into a debug pcap stream. +package capture + +import ( + "bytes" + "context" + "encoding/binary" + "io" + "net/http" + "sync" + "time" + + _ "embed" + + "tailscale.com/net/packet" + "tailscale.com/util/set" +) + +//go:embed ts-dissector.lua +var DissectorLua string + +// Callback describes a function which is called to +// record packets when debugging packet-capture. +// Such callbacks must not take ownership of the +// provided data slice: it may only copy out of it +// within the lifetime of the function. +type Callback func(Path, time.Time, []byte, packet.CaptureMeta) + +var bufferPool = sync.Pool{ + New: func() any { + return new(bytes.Buffer) + }, +} + +const flushPeriod = 100 * time.Millisecond + +func writePcapHeader(w io.Writer) { + binary.Write(w, binary.LittleEndian, uint32(0xA1B2C3D4)) // pcap magic number + binary.Write(w, binary.LittleEndian, uint16(2)) // version major + binary.Write(w, binary.LittleEndian, uint16(4)) // version minor + binary.Write(w, binary.LittleEndian, uint32(0)) // this zone + binary.Write(w, binary.LittleEndian, uint32(0)) // zone significant figures + binary.Write(w, binary.LittleEndian, uint32(65535)) // max packet len + binary.Write(w, binary.LittleEndian, uint32(147)) // link-layer ID - USER0 +} + +func writePktHeader(w *bytes.Buffer, when time.Time, length int) { + s := when.Unix() + us := when.UnixMicro() - (s * 1000000) + + binary.Write(w, binary.LittleEndian, uint32(s)) // timestamp in seconds + binary.Write(w, binary.LittleEndian, uint32(us)) // timestamp microseconds + binary.Write(w, binary.LittleEndian, uint32(length)) // length present + binary.Write(w, binary.LittleEndian, uint32(length)) // total length +} + +// Path describes where in the data path the packet was captured. +type Path uint8 + +// Valid Path values. +const ( + // FromLocal indicates the packet was logged as it traversed the FromLocal path: + // i.e.: A packet from the local system into the TUN. + FromLocal Path = 0 + // FromPeer indicates the packet was logged upon reception from a remote peer. + FromPeer Path = 1 + // SynthesizedToLocal indicates the packet was generated from within tailscaled, + // and is being routed to the local machine's network stack. + SynthesizedToLocal Path = 2 + // SynthesizedToPeer indicates the packet was generated from within tailscaled, + // and is being routed to a remote Wireguard peer. + SynthesizedToPeer Path = 3 + + // PathDisco indicates the packet is information about a disco frame. + PathDisco Path = 254 +) + +// New creates a new capture sink. +func New() *Sink { + ctx, c := context.WithCancel(context.Background()) + return &Sink{ + ctx: ctx, + ctxCancel: c, + } +} + +// Type Sink handles callbacks with packets to be logged, +// formatting them into a pcap stream which is mirrored to +// all registered outputs. +type Sink struct { + ctx context.Context + ctxCancel context.CancelFunc + + mu sync.Mutex + outputs set.HandleSet[io.Writer] + flushTimer *time.Timer // or nil if none running +} + +// RegisterOutput connects an output to this sink, which +// will be written to with a pcap stream as packets are logged. +// A function is returned which unregisters the output when +// called. +// +// If w implements io.Closer, it will be closed upon error +// or when the sink is closed. If w implements http.Flusher, +// it will be flushed periodically. +func (s *Sink) RegisterOutput(w io.Writer) (unregister func()) { + select { + case <-s.ctx.Done(): + return func() {} + default: + } + + writePcapHeader(w) + s.mu.Lock() + hnd := s.outputs.Add(w) + s.mu.Unlock() + + return func() { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.outputs, hnd) + } +} + +// NumOutputs returns the number of outputs registered with the sink. +func (s *Sink) NumOutputs() int { + s.mu.Lock() + defer s.mu.Unlock() + return len(s.outputs) +} + +// Close shuts down the sink. Future calls to LogPacket +// are ignored, and any registered output that implements +// io.Closer is closed. +func (s *Sink) Close() error { + s.ctxCancel() + s.mu.Lock() + defer s.mu.Unlock() + if s.flushTimer != nil { + s.flushTimer.Stop() + s.flushTimer = nil + } + + for _, o := range s.outputs { + if o, ok := o.(io.Closer); ok { + o.Close() + } + } + s.outputs = nil + return nil +} + +// WaitCh returns a channel which blocks until +// the sink is closed. +func (s *Sink) WaitCh() <-chan struct{} { + return s.ctx.Done() +} + +func customDataLen(meta packet.CaptureMeta) int { + length := 4 + if meta.DidSNAT { + length += meta.OriginalSrc.Addr().BitLen() / 8 + } + if meta.DidDNAT { + length += meta.OriginalDst.Addr().BitLen() / 8 + } + return length +} + +// LogPacket is called to insert a packet into the capture. +// +// This function does not take ownership of the provided data slice. +func (s *Sink) LogPacket(path Path, when time.Time, data []byte, meta packet.CaptureMeta) { + select { + case <-s.ctx.Done(): + return + default: + } + + extraLen := customDataLen(meta) + b := bufferPool.Get().(*bytes.Buffer) + b.Reset() + b.Grow(16 + extraLen + len(data)) // 16b pcap header + len(metadata) + len(payload) + defer bufferPool.Put(b) + + writePktHeader(b, when, len(data)+extraLen) + + // Custom tailscale debugging data + binary.Write(b, binary.LittleEndian, uint16(path)) + if meta.DidSNAT { + binary.Write(b, binary.LittleEndian, uint8(meta.OriginalSrc.Addr().BitLen()/8)) + b.Write(meta.OriginalSrc.Addr().AsSlice()) + } else { + binary.Write(b, binary.LittleEndian, uint8(0)) // SNAT addr len == 0 + } + if meta.DidDNAT { + binary.Write(b, binary.LittleEndian, uint8(meta.OriginalDst.Addr().BitLen()/8)) + b.Write(meta.OriginalDst.Addr().AsSlice()) + } else { + binary.Write(b, binary.LittleEndian, uint8(0)) // DNAT addr len == 0 + } + + b.Write(data) + + s.mu.Lock() + defer s.mu.Unlock() + + var hadError []set.Handle + for hnd, o := range s.outputs { + if _, err := o.Write(b.Bytes()); err != nil { + hadError = append(hadError, hnd) + continue + } + } + for _, hnd := range hadError { + if o, ok := s.outputs[hnd].(io.Closer); ok { + o.Close() + } + delete(s.outputs, hnd) + } + + if s.flushTimer == nil { + s.flushTimer = time.AfterFunc(flushPeriod, func() { + s.mu.Lock() + defer s.mu.Unlock() + for _, o := range s.outputs { + if f, ok := o.(http.Flusher); ok { + f.Flush() + } + } + s.flushTimer = nil + }) + } +} diff --git a/wgengine/magicsock/blockforever_conn.go b/wgengine/magicsock/blockforever_conn.go index 58359acdd..f2e85dcd5 100644 --- a/wgengine/magicsock/blockforever_conn.go +++ b/wgengine/magicsock/blockforever_conn.go @@ -1,55 +1,55 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package magicsock - -import ( - "errors" - "net" - "net/netip" - "sync" - "syscall" - "time" -) - -// blockForeverConn is a net.PacketConn whose reads block until it is closed. -type blockForeverConn struct { - mu sync.Mutex - cond *sync.Cond - closed bool -} - -func (c *blockForeverConn) ReadFromUDPAddrPort(p []byte) (n int, addr netip.AddrPort, err error) { - c.mu.Lock() - for !c.closed { - c.cond.Wait() - } - c.mu.Unlock() - return 0, netip.AddrPort{}, net.ErrClosed -} - -func (c *blockForeverConn) WriteToUDPAddrPort(p []byte, addr netip.AddrPort) (int, error) { - // Silently drop writes. - return len(p), nil -} - -func (c *blockForeverConn) LocalAddr() net.Addr { - // Return a *net.UDPAddr because lots of code assumes that it will. - return new(net.UDPAddr) -} - -func (c *blockForeverConn) Close() error { - c.mu.Lock() - defer c.mu.Unlock() - if c.closed { - return net.ErrClosed - } - c.closed = true - c.cond.Broadcast() - return nil -} - -func (c *blockForeverConn) SetDeadline(t time.Time) error { return errors.New("unimplemented") } -func (c *blockForeverConn) SetReadDeadline(t time.Time) error { return errors.New("unimplemented") } -func (c *blockForeverConn) SetWriteDeadline(t time.Time) error { return errors.New("unimplemented") } -func (c *blockForeverConn) SyscallConn() (syscall.RawConn, error) { return nil, errUnsupportedConnType } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "errors" + "net" + "net/netip" + "sync" + "syscall" + "time" +) + +// blockForeverConn is a net.PacketConn whose reads block until it is closed. +type blockForeverConn struct { + mu sync.Mutex + cond *sync.Cond + closed bool +} + +func (c *blockForeverConn) ReadFromUDPAddrPort(p []byte) (n int, addr netip.AddrPort, err error) { + c.mu.Lock() + for !c.closed { + c.cond.Wait() + } + c.mu.Unlock() + return 0, netip.AddrPort{}, net.ErrClosed +} + +func (c *blockForeverConn) WriteToUDPAddrPort(p []byte, addr netip.AddrPort) (int, error) { + // Silently drop writes. + return len(p), nil +} + +func (c *blockForeverConn) LocalAddr() net.Addr { + // Return a *net.UDPAddr because lots of code assumes that it will. + return new(net.UDPAddr) +} + +func (c *blockForeverConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return net.ErrClosed + } + c.closed = true + c.cond.Broadcast() + return nil +} + +func (c *blockForeverConn) SetDeadline(t time.Time) error { return errors.New("unimplemented") } +func (c *blockForeverConn) SetReadDeadline(t time.Time) error { return errors.New("unimplemented") } +func (c *blockForeverConn) SetWriteDeadline(t time.Time) error { return errors.New("unimplemented") } +func (c *blockForeverConn) SyscallConn() (syscall.RawConn, error) { return nil, errUnsupportedConnType } diff --git a/wgengine/magicsock/endpoint_default.go b/wgengine/magicsock/endpoint_default.go index 9ffeef5f8..1ed6e5e0e 100644 --- a/wgengine/magicsock/endpoint_default.go +++ b/wgengine/magicsock/endpoint_default.go @@ -1,22 +1,22 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !js && !wasm && !plan9 - -package magicsock - -import ( - "errors" - "syscall" -) - -// errHOSTUNREACH wraps unix.EHOSTUNREACH in an interface type to pass to -// errors.Is while avoiding an allocation per call. -var errHOSTUNREACH error = syscall.EHOSTUNREACH - -// isBadEndpointErr checks if err is one which is known to report that an -// endpoint can no longer be sent to. It is not exhaustive, and for unknown -// errors always reports false. -func isBadEndpointErr(err error) bool { - return errors.Is(err, errHOSTUNREACH) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !js && !wasm && !plan9 + +package magicsock + +import ( + "errors" + "syscall" +) + +// errHOSTUNREACH wraps unix.EHOSTUNREACH in an interface type to pass to +// errors.Is while avoiding an allocation per call. +var errHOSTUNREACH error = syscall.EHOSTUNREACH + +// isBadEndpointErr checks if err is one which is known to report that an +// endpoint can no longer be sent to. It is not exhaustive, and for unknown +// errors always reports false. +func isBadEndpointErr(err error) bool { + return errors.Is(err, errHOSTUNREACH) +} diff --git a/wgengine/magicsock/endpoint_stub.go b/wgengine/magicsock/endpoint_stub.go index 9a5c9d937..a209c352b 100644 --- a/wgengine/magicsock/endpoint_stub.go +++ b/wgengine/magicsock/endpoint_stub.go @@ -1,13 +1,13 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build wasm || plan9 - -package magicsock - -// isBadEndpointErr checks if err is one which is known to report that an -// endpoint can no longer be sent to. It is not exhaustive, but covers known -// cases. -func isBadEndpointErr(err error) bool { - return false -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build wasm || plan9 + +package magicsock + +// isBadEndpointErr checks if err is one which is known to report that an +// endpoint can no longer be sent to. It is not exhaustive, but covers known +// cases. +func isBadEndpointErr(err error) bool { + return false +} diff --git a/wgengine/magicsock/endpoint_tracker.go b/wgengine/magicsock/endpoint_tracker.go index e2ac926b4..5caddd1a0 100644 --- a/wgengine/magicsock/endpoint_tracker.go +++ b/wgengine/magicsock/endpoint_tracker.go @@ -1,248 +1,248 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package magicsock - -import ( - "net/netip" - "slices" - "sync" - "time" - - "tailscale.com/tailcfg" - "tailscale.com/tempfork/heap" - "tailscale.com/util/mak" - "tailscale.com/util/set" -) - -const ( - // endpointTrackerLifetime is how long we continue advertising an - // endpoint after we last see it. This is intentionally chosen to be - // slightly longer than a full netcheck period. - endpointTrackerLifetime = 5*time.Minute + 10*time.Second - - // endpointTrackerMaxPerAddr is how many cached addresses we track for - // a given netip.Addr. This allows e.g. restricting the number of STUN - // endpoints we cache (which usually have the same netip.Addr but - // different ports). - // - // The value of 6 is chosen because we can advertise up to 3 endpoints - // based on the STUN IP: - // 1. The STUN endpoint itself (EndpointSTUN) - // 2. The STUN IP with the local Tailscale port (EndpointSTUN4LocalPort) - // 3. The STUN IP with a portmapped port (EndpointPortmapped) - // - // Storing 6 endpoints in the cache means we can store up to 2 previous - // sets of endpoints. - endpointTrackerMaxPerAddr = 6 -) - -// endpointTrackerEntry is an entry in an endpointHeap that stores the state of -// a given cached endpoint. -type endpointTrackerEntry struct { - // endpoint is the cached endpoint. - endpoint tailcfg.Endpoint - // until is the time until which this endpoint is being cached. - until time.Time - // index is the index within the containing endpointHeap. - index int -} - -// endpointHeap is an ordered heap of endpointTrackerEntry structs, ordered in -// ascending order by the 'until' expiry time (i.e. oldest first). -type endpointHeap []*endpointTrackerEntry - -var _ heap.Interface[*endpointTrackerEntry] = (*endpointHeap)(nil) - -// Len implements heap.Interface. -func (eh endpointHeap) Len() int { return len(eh) } - -// Less implements heap.Interface. -func (eh endpointHeap) Less(i, j int) bool { - // We want to store items so that the lowest item in the heap is the - // oldest, so that heap.Pop()-ing from the endpointHeap will remove the - // oldest entry. - return eh[i].until.Before(eh[j].until) -} - -// Swap implements heap.Interface. -func (eh endpointHeap) Swap(i, j int) { - eh[i], eh[j] = eh[j], eh[i] - eh[i].index = i - eh[j].index = j -} - -// Push implements heap.Interface. -func (eh *endpointHeap) Push(item *endpointTrackerEntry) { - n := len(*eh) - item.index = n - *eh = append(*eh, item) -} - -// Pop implements heap.Interface. -func (eh *endpointHeap) Pop() *endpointTrackerEntry { - old := *eh - n := len(old) - item := old[n-1] - old[n-1] = nil // avoid memory leak - item.index = -1 // for safety - *eh = old[0 : n-1] - return item -} - -// Min returns a pointer to the minimum element in the heap, without removing -// it. Since this is a min-heap ordered by the 'until' field, this returns the -// chronologically "earliest" element in the heap. -// -// Len() must be non-zero. -func (eh endpointHeap) Min() *endpointTrackerEntry { - return eh[0] -} - -// endpointTracker caches endpoints that are advertised to peers. This allows -// peers to still reach this node if there's a temporary endpoint flap; rather -// than withdrawing an endpoint and then re-advertising it the next time we run -// a netcheck, we keep advertising the endpoint until it's not present for a -// defined timeout. -// -// See tailscale/tailscale#7877 for more information. -type endpointTracker struct { - mu sync.Mutex - endpoints map[netip.Addr]*endpointHeap -} - -// update takes as input the current sent of discovered endpoints and the -// current time, and returns the set of endpoints plus any previous-cached and -// non-expired endpoints that should be advertised to peers. -func (et *endpointTracker) update(now time.Time, eps []tailcfg.Endpoint) (epsPlusCached []tailcfg.Endpoint) { - var inputEps set.Slice[netip.AddrPort] - for _, ep := range eps { - inputEps.Add(ep.Addr) - } - - et.mu.Lock() - defer et.mu.Unlock() - - // Extend endpoints that already exist in the cache. We do this before - // we remove expired endpoints, below, so we don't remove something - // that would otherwise have survived by extending. - until := now.Add(endpointTrackerLifetime) - for _, ep := range eps { - et.extendLocked(ep, until) - } - - // Now that we've extended existing endpoints, remove everything that - // has expired. - et.removeExpiredLocked(now) - - // Add entries from the input set of endpoints into the cache; we do - // this after removing expired ones so that we can store as many as - // possible, with space freed by the entries removed after expiry. - for _, ep := range eps { - et.addLocked(now, ep, until) - } - - // Finally, add entries to the return array that aren't already there. - epsPlusCached = eps - for _, heap := range et.endpoints { - for _, ep := range *heap { - // If the endpoint was in the input list, or has expired, skip it. - if inputEps.Contains(ep.endpoint.Addr) { - continue - } else if now.After(ep.until) { - // Defense-in-depth; should never happen since - // we removed expired entries above, but ignore - // it anyway. - continue - } - - // We haven't seen this endpoint; add to the return array - epsPlusCached = append(epsPlusCached, ep.endpoint) - } - } - - return epsPlusCached -} - -// extendLocked will update the expiry time of the provided endpoint in the -// cache, if it is present. If it is not present, nothing will be done. -// -// et.mu must be held. -func (et *endpointTracker) extendLocked(ep tailcfg.Endpoint, until time.Time) { - key := ep.Addr.Addr() - epHeap, found := et.endpoints[key] - if !found { - return - } - - // Find the entry for this exact address; this loop is quick since we - // bound the number of items in the heap. - // - // TODO(andrew): this means we iterate over the entire heap once per - // endpoint; even if the heap is small, if we have a lot of input - // endpoints this can be expensive? - for i, entry := range *epHeap { - if entry.endpoint == ep { - entry.until = until - heap.Fix(epHeap, i) - return - } - } -} - -// addLocked will store the provided endpoint(s) in the cache for a fixed -// period of time, ensuring that the size of the endpoint cache remains below -// the maximum. -// -// et.mu must be held. -func (et *endpointTracker) addLocked(now time.Time, ep tailcfg.Endpoint, until time.Time) { - key := ep.Addr.Addr() - - // Create or get the heap for this endpoint's addr - epHeap := et.endpoints[key] - if epHeap == nil { - epHeap = new(endpointHeap) - mak.Set(&et.endpoints, key, epHeap) - } - - // Find the entry for this exact address; this loop is quick - // since we bound the number of items in the heap. - found := slices.ContainsFunc(*epHeap, func(v *endpointTrackerEntry) bool { - return v.endpoint == ep - }) - if !found { - // Add address to heap; either the endpoint is new, or the heap - // was newly-created and thus empty. - heap.Push(epHeap, &endpointTrackerEntry{endpoint: ep, until: until}) - } - - // Now that we've added everything, pop from our heap until we're below - // the limit. This is a min-heap, so popping removes the lowest (and - // thus oldest) endpoint. - for epHeap.Len() > endpointTrackerMaxPerAddr { - heap.Pop(epHeap) - } -} - -// removeExpired will remove all expired entries from the cache. -// -// et.mu must be held. -func (et *endpointTracker) removeExpiredLocked(now time.Time) { - for k, epHeap := range et.endpoints { - // The minimum element is oldest/earliest endpoint; repeatedly - // pop from the heap while it's in the past. - for epHeap.Len() > 0 { - minElem := epHeap.Min() - if now.After(minElem.until) { - heap.Pop(epHeap) - } else { - break - } - } - - if epHeap.Len() == 0 { - // Free up space in the map by removing the empty heap. - delete(et.endpoints, k) - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "net/netip" + "slices" + "sync" + "time" + + "tailscale.com/tailcfg" + "tailscale.com/tempfork/heap" + "tailscale.com/util/mak" + "tailscale.com/util/set" +) + +const ( + // endpointTrackerLifetime is how long we continue advertising an + // endpoint after we last see it. This is intentionally chosen to be + // slightly longer than a full netcheck period. + endpointTrackerLifetime = 5*time.Minute + 10*time.Second + + // endpointTrackerMaxPerAddr is how many cached addresses we track for + // a given netip.Addr. This allows e.g. restricting the number of STUN + // endpoints we cache (which usually have the same netip.Addr but + // different ports). + // + // The value of 6 is chosen because we can advertise up to 3 endpoints + // based on the STUN IP: + // 1. The STUN endpoint itself (EndpointSTUN) + // 2. The STUN IP with the local Tailscale port (EndpointSTUN4LocalPort) + // 3. The STUN IP with a portmapped port (EndpointPortmapped) + // + // Storing 6 endpoints in the cache means we can store up to 2 previous + // sets of endpoints. + endpointTrackerMaxPerAddr = 6 +) + +// endpointTrackerEntry is an entry in an endpointHeap that stores the state of +// a given cached endpoint. +type endpointTrackerEntry struct { + // endpoint is the cached endpoint. + endpoint tailcfg.Endpoint + // until is the time until which this endpoint is being cached. + until time.Time + // index is the index within the containing endpointHeap. + index int +} + +// endpointHeap is an ordered heap of endpointTrackerEntry structs, ordered in +// ascending order by the 'until' expiry time (i.e. oldest first). +type endpointHeap []*endpointTrackerEntry + +var _ heap.Interface[*endpointTrackerEntry] = (*endpointHeap)(nil) + +// Len implements heap.Interface. +func (eh endpointHeap) Len() int { return len(eh) } + +// Less implements heap.Interface. +func (eh endpointHeap) Less(i, j int) bool { + // We want to store items so that the lowest item in the heap is the + // oldest, so that heap.Pop()-ing from the endpointHeap will remove the + // oldest entry. + return eh[i].until.Before(eh[j].until) +} + +// Swap implements heap.Interface. +func (eh endpointHeap) Swap(i, j int) { + eh[i], eh[j] = eh[j], eh[i] + eh[i].index = i + eh[j].index = j +} + +// Push implements heap.Interface. +func (eh *endpointHeap) Push(item *endpointTrackerEntry) { + n := len(*eh) + item.index = n + *eh = append(*eh, item) +} + +// Pop implements heap.Interface. +func (eh *endpointHeap) Pop() *endpointTrackerEntry { + old := *eh + n := len(old) + item := old[n-1] + old[n-1] = nil // avoid memory leak + item.index = -1 // for safety + *eh = old[0 : n-1] + return item +} + +// Min returns a pointer to the minimum element in the heap, without removing +// it. Since this is a min-heap ordered by the 'until' field, this returns the +// chronologically "earliest" element in the heap. +// +// Len() must be non-zero. +func (eh endpointHeap) Min() *endpointTrackerEntry { + return eh[0] +} + +// endpointTracker caches endpoints that are advertised to peers. This allows +// peers to still reach this node if there's a temporary endpoint flap; rather +// than withdrawing an endpoint and then re-advertising it the next time we run +// a netcheck, we keep advertising the endpoint until it's not present for a +// defined timeout. +// +// See tailscale/tailscale#7877 for more information. +type endpointTracker struct { + mu sync.Mutex + endpoints map[netip.Addr]*endpointHeap +} + +// update takes as input the current sent of discovered endpoints and the +// current time, and returns the set of endpoints plus any previous-cached and +// non-expired endpoints that should be advertised to peers. +func (et *endpointTracker) update(now time.Time, eps []tailcfg.Endpoint) (epsPlusCached []tailcfg.Endpoint) { + var inputEps set.Slice[netip.AddrPort] + for _, ep := range eps { + inputEps.Add(ep.Addr) + } + + et.mu.Lock() + defer et.mu.Unlock() + + // Extend endpoints that already exist in the cache. We do this before + // we remove expired endpoints, below, so we don't remove something + // that would otherwise have survived by extending. + until := now.Add(endpointTrackerLifetime) + for _, ep := range eps { + et.extendLocked(ep, until) + } + + // Now that we've extended existing endpoints, remove everything that + // has expired. + et.removeExpiredLocked(now) + + // Add entries from the input set of endpoints into the cache; we do + // this after removing expired ones so that we can store as many as + // possible, with space freed by the entries removed after expiry. + for _, ep := range eps { + et.addLocked(now, ep, until) + } + + // Finally, add entries to the return array that aren't already there. + epsPlusCached = eps + for _, heap := range et.endpoints { + for _, ep := range *heap { + // If the endpoint was in the input list, or has expired, skip it. + if inputEps.Contains(ep.endpoint.Addr) { + continue + } else if now.After(ep.until) { + // Defense-in-depth; should never happen since + // we removed expired entries above, but ignore + // it anyway. + continue + } + + // We haven't seen this endpoint; add to the return array + epsPlusCached = append(epsPlusCached, ep.endpoint) + } + } + + return epsPlusCached +} + +// extendLocked will update the expiry time of the provided endpoint in the +// cache, if it is present. If it is not present, nothing will be done. +// +// et.mu must be held. +func (et *endpointTracker) extendLocked(ep tailcfg.Endpoint, until time.Time) { + key := ep.Addr.Addr() + epHeap, found := et.endpoints[key] + if !found { + return + } + + // Find the entry for this exact address; this loop is quick since we + // bound the number of items in the heap. + // + // TODO(andrew): this means we iterate over the entire heap once per + // endpoint; even if the heap is small, if we have a lot of input + // endpoints this can be expensive? + for i, entry := range *epHeap { + if entry.endpoint == ep { + entry.until = until + heap.Fix(epHeap, i) + return + } + } +} + +// addLocked will store the provided endpoint(s) in the cache for a fixed +// period of time, ensuring that the size of the endpoint cache remains below +// the maximum. +// +// et.mu must be held. +func (et *endpointTracker) addLocked(now time.Time, ep tailcfg.Endpoint, until time.Time) { + key := ep.Addr.Addr() + + // Create or get the heap for this endpoint's addr + epHeap := et.endpoints[key] + if epHeap == nil { + epHeap = new(endpointHeap) + mak.Set(&et.endpoints, key, epHeap) + } + + // Find the entry for this exact address; this loop is quick + // since we bound the number of items in the heap. + found := slices.ContainsFunc(*epHeap, func(v *endpointTrackerEntry) bool { + return v.endpoint == ep + }) + if !found { + // Add address to heap; either the endpoint is new, or the heap + // was newly-created and thus empty. + heap.Push(epHeap, &endpointTrackerEntry{endpoint: ep, until: until}) + } + + // Now that we've added everything, pop from our heap until we're below + // the limit. This is a min-heap, so popping removes the lowest (and + // thus oldest) endpoint. + for epHeap.Len() > endpointTrackerMaxPerAddr { + heap.Pop(epHeap) + } +} + +// removeExpired will remove all expired entries from the cache. +// +// et.mu must be held. +func (et *endpointTracker) removeExpiredLocked(now time.Time) { + for k, epHeap := range et.endpoints { + // The minimum element is oldest/earliest endpoint; repeatedly + // pop from the heap while it's in the past. + for epHeap.Len() > 0 { + minElem := epHeap.Min() + if now.After(minElem.until) { + heap.Pop(epHeap) + } else { + break + } + } + + if epHeap.Len() == 0 { + // Free up space in the map by removing the empty heap. + delete(et.endpoints, k) + } + } +} diff --git a/wgengine/magicsock/magicsock_unix_test.go b/wgengine/magicsock/magicsock_unix_test.go index 9ad8cab93..b0700a8eb 100644 --- a/wgengine/magicsock/magicsock_unix_test.go +++ b/wgengine/magicsock/magicsock_unix_test.go @@ -1,60 +1,60 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build unix - -package magicsock - -import ( - "net" - "syscall" - "testing" - - "tailscale.com/types/nettype" -) - -func TestTrySetSocketBuffer(t *testing.T) { - c, err := net.ListenPacket("udp", ":0") - if err != nil { - t.Fatal(err) - } - defer c.Close() - - rc, err := c.(*net.UDPConn).SyscallConn() - if err != nil { - t.Fatal(err) - } - - getBufs := func() (int, int) { - var rcv, snd int - rc.Control(func(fd uintptr) { - rcv, err = syscall.GetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_RCVBUF) - if err != nil { - t.Errorf("getsockopt(SO_RCVBUF): %v", err) - } - snd, err = syscall.GetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_SNDBUF) - if err != nil { - t.Errorf("getsockopt(SO_SNDBUF): %v", err) - } - }) - return rcv, snd - } - - curRcv, curSnd := getBufs() - - trySetSocketBuffer(c.(nettype.PacketConn), t.Logf) - - newRcv, newSnd := getBufs() - - if curRcv > newRcv { - t.Errorf("SO_RCVBUF decreased: %v -> %v", curRcv, newRcv) - } - if curSnd > newSnd { - t.Errorf("SO_SNDBUF decreased: %v -> %v", curSnd, newSnd) - } - - // On many systems we may not increase the value, particularly running as a - // regular user, so log the information for manual verification. - t.Logf("SO_RCVBUF: %v -> %v", curRcv, newRcv) - t.Logf("SO_SNDBUF: %v -> %v", curRcv, newRcv) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build unix + +package magicsock + +import ( + "net" + "syscall" + "testing" + + "tailscale.com/types/nettype" +) + +func TestTrySetSocketBuffer(t *testing.T) { + c, err := net.ListenPacket("udp", ":0") + if err != nil { + t.Fatal(err) + } + defer c.Close() + + rc, err := c.(*net.UDPConn).SyscallConn() + if err != nil { + t.Fatal(err) + } + + getBufs := func() (int, int) { + var rcv, snd int + rc.Control(func(fd uintptr) { + rcv, err = syscall.GetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_RCVBUF) + if err != nil { + t.Errorf("getsockopt(SO_RCVBUF): %v", err) + } + snd, err = syscall.GetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_SNDBUF) + if err != nil { + t.Errorf("getsockopt(SO_SNDBUF): %v", err) + } + }) + return rcv, snd + } + + curRcv, curSnd := getBufs() + + trySetSocketBuffer(c.(nettype.PacketConn), t.Logf) + + newRcv, newSnd := getBufs() + + if curRcv > newRcv { + t.Errorf("SO_RCVBUF decreased: %v -> %v", curRcv, newRcv) + } + if curSnd > newSnd { + t.Errorf("SO_SNDBUF decreased: %v -> %v", curSnd, newSnd) + } + + // On many systems we may not increase the value, particularly running as a + // regular user, so log the information for manual verification. + t.Logf("SO_RCVBUF: %v -> %v", curRcv, newRcv) + t.Logf("SO_SNDBUF: %v -> %v", curRcv, newRcv) +} diff --git a/wgengine/magicsock/peermtu_darwin.go b/wgengine/magicsock/peermtu_darwin.go index b2a1ed217..a0a1aacb5 100644 --- a/wgengine/magicsock/peermtu_darwin.go +++ b/wgengine/magicsock/peermtu_darwin.go @@ -1,51 +1,51 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build darwin && !ios - -package magicsock - -import ( - "syscall" - - "golang.org/x/sys/unix" -) - -func getDontFragOpt(network string) int { - if network == "udp4" { - return unix.IP_DONTFRAG - } - return unix.IPV6_DONTFRAG -} - -func (c *Conn) setDontFragment(network string, enable bool) error { - optArg := 1 - if enable == false { - optArg = 0 - } - var err error - rcErr := c.connControl(network, func(fd uintptr) { - err = syscall.SetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network), optArg) - }) - - if rcErr != nil { - return rcErr - } - return err -} - -func (c *Conn) getDontFragment(network string) (bool, error) { - var v int - var err error - rcErr := c.connControl(network, func(fd uintptr) { - v, err = syscall.GetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network)) - }) - - if rcErr != nil { - return false, rcErr - } - if v == 1 { - return true, err - } - return false, err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build darwin && !ios + +package magicsock + +import ( + "syscall" + + "golang.org/x/sys/unix" +) + +func getDontFragOpt(network string) int { + if network == "udp4" { + return unix.IP_DONTFRAG + } + return unix.IPV6_DONTFRAG +} + +func (c *Conn) setDontFragment(network string, enable bool) error { + optArg := 1 + if enable == false { + optArg = 0 + } + var err error + rcErr := c.connControl(network, func(fd uintptr) { + err = syscall.SetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network), optArg) + }) + + if rcErr != nil { + return rcErr + } + return err +} + +func (c *Conn) getDontFragment(network string) (bool, error) { + var v int + var err error + rcErr := c.connControl(network, func(fd uintptr) { + v, err = syscall.GetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network)) + }) + + if rcErr != nil { + return false, rcErr + } + if v == 1 { + return true, err + } + return false, err +} diff --git a/wgengine/magicsock/peermtu_linux.go b/wgengine/magicsock/peermtu_linux.go index d32ead099..b76f30f08 100644 --- a/wgengine/magicsock/peermtu_linux.go +++ b/wgengine/magicsock/peermtu_linux.go @@ -1,49 +1,49 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux && !android - -package magicsock - -import ( - "syscall" -) - -func getDontFragOpt(network string) int { - if network == "udp4" { - return syscall.IP_MTU_DISCOVER - } - return syscall.IPV6_MTU_DISCOVER -} - -func (c *Conn) setDontFragment(network string, enable bool) error { - optArg := syscall.IP_PMTUDISC_DO - if enable == false { - optArg = syscall.IP_PMTUDISC_DONT - } - var err error - rcErr := c.connControl(network, func(fd uintptr) { - err = syscall.SetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network), optArg) - }) - - if rcErr != nil { - return rcErr - } - return err -} - -func (c *Conn) getDontFragment(network string) (bool, error) { - var v int - var err error - rcErr := c.connControl(network, func(fd uintptr) { - v, err = syscall.GetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network)) - }) - - if rcErr != nil { - return false, rcErr - } - if v == syscall.IP_PMTUDISC_DO { - return true, err - } - return false, err -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux && !android + +package magicsock + +import ( + "syscall" +) + +func getDontFragOpt(network string) int { + if network == "udp4" { + return syscall.IP_MTU_DISCOVER + } + return syscall.IPV6_MTU_DISCOVER +} + +func (c *Conn) setDontFragment(network string, enable bool) error { + optArg := syscall.IP_PMTUDISC_DO + if enable == false { + optArg = syscall.IP_PMTUDISC_DONT + } + var err error + rcErr := c.connControl(network, func(fd uintptr) { + err = syscall.SetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network), optArg) + }) + + if rcErr != nil { + return rcErr + } + return err +} + +func (c *Conn) getDontFragment(network string) (bool, error) { + var v int + var err error + rcErr := c.connControl(network, func(fd uintptr) { + v, err = syscall.GetsockoptInt(int(fd), getIPProto(network), getDontFragOpt(network)) + }) + + if rcErr != nil { + return false, rcErr + } + if v == syscall.IP_PMTUDISC_DO { + return true, err + } + return false, err +} diff --git a/wgengine/magicsock/peermtu_unix.go b/wgengine/magicsock/peermtu_unix.go index 59e808ee7..eec3d744f 100644 --- a/wgengine/magicsock/peermtu_unix.go +++ b/wgengine/magicsock/peermtu_unix.go @@ -1,42 +1,42 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build (darwin && !ios) || (linux && !android) - -package magicsock - -import ( - "syscall" -) - -// getIPProto returns the value of the get/setsockopt proto argument necessary -// to set an IP sockopt that corresponds with the string network, which must be -// "udp4" or "udp6". -func getIPProto(network string) int { - if network == "udp4" { - return syscall.IPPROTO_IP - } - return syscall.IPPROTO_IPV6 -} - -// connControl allows the caller to run a system call on the socket underlying -// Conn specified by the string network, which must be "udp4" or "udp6". If the -// pconn type implements the syscall method, this function returns the value of -// of the system call fn called with the fd of the socket as its arg (or the -// error from rc.Control() if that fails). Otherwise it returns the error -// errUnsupportedConnType. -func (c *Conn) connControl(network string, fn func(fd uintptr)) error { - pconn := c.pconn4.pconn - if network == "udp6" { - pconn = c.pconn6.pconn - } - sc, ok := pconn.(syscall.Conn) - if !ok { - return errUnsupportedConnType - } - rc, err := sc.SyscallConn() - if err != nil { - return err - } - return rc.Control(fn) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build (darwin && !ios) || (linux && !android) + +package magicsock + +import ( + "syscall" +) + +// getIPProto returns the value of the get/setsockopt proto argument necessary +// to set an IP sockopt that corresponds with the string network, which must be +// "udp4" or "udp6". +func getIPProto(network string) int { + if network == "udp4" { + return syscall.IPPROTO_IP + } + return syscall.IPPROTO_IPV6 +} + +// connControl allows the caller to run a system call on the socket underlying +// Conn specified by the string network, which must be "udp4" or "udp6". If the +// pconn type implements the syscall method, this function returns the value of +// of the system call fn called with the fd of the socket as its arg (or the +// error from rc.Control() if that fails). Otherwise it returns the error +// errUnsupportedConnType. +func (c *Conn) connControl(network string, fn func(fd uintptr)) error { + pconn := c.pconn4.pconn + if network == "udp6" { + pconn = c.pconn6.pconn + } + sc, ok := pconn.(syscall.Conn) + if !ok { + return errUnsupportedConnType + } + rc, err := sc.SyscallConn() + if err != nil { + return err + } + return rc.Control(fn) +} diff --git a/wgengine/mem_ios.go b/wgengine/mem_ios.go index 975dfca61..cc266ea3a 100644 --- a/wgengine/mem_ios.go +++ b/wgengine/mem_ios.go @@ -1,20 +1,20 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package wgengine - -import ( - "github.com/tailscale/wireguard-go/device" -) - -// iOS has a very restrictive memory limit on network extensions. -// Reduce the maximum amount of memory that wireguard-go can allocate -// to avoid getting killed. - -func init() { - device.QueueStagedSize = 64 - device.QueueOutboundSize = 64 - device.QueueInboundSize = 64 - device.QueueHandshakeSize = 64 - device.PreallocatedBuffersPerPool = 64 -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package wgengine + +import ( + "github.com/tailscale/wireguard-go/device" +) + +// iOS has a very restrictive memory limit on network extensions. +// Reduce the maximum amount of memory that wireguard-go can allocate +// to avoid getting killed. + +func init() { + device.QueueStagedSize = 64 + device.QueueOutboundSize = 64 + device.QueueInboundSize = 64 + device.QueueHandshakeSize = 64 + device.PreallocatedBuffersPerPool = 64 +} diff --git a/wgengine/netstack/netstack_linux.go b/wgengine/netstack/netstack_linux.go index 9e27b7819..a0bfb4456 100644 --- a/wgengine/netstack/netstack_linux.go +++ b/wgengine/netstack/netstack_linux.go @@ -1,19 +1,19 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package netstack - -import ( - "os/exec" - "syscall" - - "golang.org/x/sys/unix" -) - -func init() { - setAmbientCapsRaw = func(cmd *exec.Cmd) { - cmd.SysProcAttr = &syscall.SysProcAttr{ - AmbientCaps: []uintptr{unix.CAP_NET_RAW}, - } - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package netstack + +import ( + "os/exec" + "syscall" + + "golang.org/x/sys/unix" +) + +func init() { + setAmbientCapsRaw = func(cmd *exec.Cmd) { + cmd.SysProcAttr = &syscall.SysProcAttr{ + AmbientCaps: []uintptr{unix.CAP_NET_RAW}, + } + } +} diff --git a/wgengine/router/runner.go b/wgengine/router/runner.go index 7ba633344..8fa068e33 100644 --- a/wgengine/router/runner.go +++ b/wgengine/router/runner.go @@ -1,120 +1,120 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux - -package router - -import ( - "errors" - "fmt" - "os" - "os/exec" - "strconv" - "strings" - "syscall" - - "golang.org/x/sys/unix" -) - -// commandRunner abstracts helpers to run OS commands. It exists -// purely to swap out osCommandRunner (below) with a fake runner in -// tests. -type commandRunner interface { - run(...string) error - output(...string) ([]byte, error) -} - -type osCommandRunner struct { - // ambientCapNetAdmin determines whether commands are executed with - // CAP_NET_ADMIN. - // CAP_NET_ADMIN is required when running as non-root and executing cmds - // like `ip rule`. Even if our process has the capability, we need to - // explicitly grant it to the new process. - // We specifically need this for Synology DSM7 where tailscaled no longer - // runs as root. - ambientCapNetAdmin bool -} - -// errCode extracts and returns the process exit code from err, or -// zero if err is nil. -func errCode(err error) int { - if err == nil { - return 0 - } - var e *exec.ExitError - if ok := errors.As(err, &e); ok { - return e.ExitCode() - } - s := err.Error() - if strings.HasPrefix(s, "exitcode:") { - code, err := strconv.Atoi(s[9:]) - if err == nil { - return code - } - } - return -42 -} - -func (o osCommandRunner) run(args ...string) error { - _, err := o.output(args...) - return err -} - -func (o osCommandRunner) output(args ...string) ([]byte, error) { - if len(args) == 0 { - return nil, errors.New("cmd: no argv[0]") - } - - cmd := exec.Command(args[0], args[1:]...) - cmd.Env = append(os.Environ(), "LC_ALL=C") - if o.ambientCapNetAdmin { - cmd.SysProcAttr = &syscall.SysProcAttr{ - AmbientCaps: []uintptr{unix.CAP_NET_ADMIN}, - } - } - out, err := cmd.CombinedOutput() - if err != nil { - return nil, fmt.Errorf("running %q failed: %w\n%s", strings.Join(args, " "), err, out) - } - - return out, nil -} - -type runGroup struct { - OkCode []int // error codes that are acceptable, other than 0, if any - Runner commandRunner // the runner that actually runs our commands - ErrAcc error // first error encountered, if any -} - -func newRunGroup(okCode []int, runner commandRunner) *runGroup { - return &runGroup{ - OkCode: okCode, - Runner: runner, - } -} - -func (rg *runGroup) okCode(err error) bool { - got := errCode(err) - for _, want := range rg.OkCode { - if got == want { - return true - } - } - return false -} - -func (rg *runGroup) Output(args ...string) []byte { - b, err := rg.Runner.output(args...) - if rg.ErrAcc == nil && err != nil && !rg.okCode(err) { - rg.ErrAcc = err - } - return b -} - -func (rg *runGroup) Run(args ...string) { - err := rg.Runner.run(args...) - if rg.ErrAcc == nil && err != nil && !rg.okCode(err) { - rg.ErrAcc = err - } -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package router + +import ( + "errors" + "fmt" + "os" + "os/exec" + "strconv" + "strings" + "syscall" + + "golang.org/x/sys/unix" +) + +// commandRunner abstracts helpers to run OS commands. It exists +// purely to swap out osCommandRunner (below) with a fake runner in +// tests. +type commandRunner interface { + run(...string) error + output(...string) ([]byte, error) +} + +type osCommandRunner struct { + // ambientCapNetAdmin determines whether commands are executed with + // CAP_NET_ADMIN. + // CAP_NET_ADMIN is required when running as non-root and executing cmds + // like `ip rule`. Even if our process has the capability, we need to + // explicitly grant it to the new process. + // We specifically need this for Synology DSM7 where tailscaled no longer + // runs as root. + ambientCapNetAdmin bool +} + +// errCode extracts and returns the process exit code from err, or +// zero if err is nil. +func errCode(err error) int { + if err == nil { + return 0 + } + var e *exec.ExitError + if ok := errors.As(err, &e); ok { + return e.ExitCode() + } + s := err.Error() + if strings.HasPrefix(s, "exitcode:") { + code, err := strconv.Atoi(s[9:]) + if err == nil { + return code + } + } + return -42 +} + +func (o osCommandRunner) run(args ...string) error { + _, err := o.output(args...) + return err +} + +func (o osCommandRunner) output(args ...string) ([]byte, error) { + if len(args) == 0 { + return nil, errors.New("cmd: no argv[0]") + } + + cmd := exec.Command(args[0], args[1:]...) + cmd.Env = append(os.Environ(), "LC_ALL=C") + if o.ambientCapNetAdmin { + cmd.SysProcAttr = &syscall.SysProcAttr{ + AmbientCaps: []uintptr{unix.CAP_NET_ADMIN}, + } + } + out, err := cmd.CombinedOutput() + if err != nil { + return nil, fmt.Errorf("running %q failed: %w\n%s", strings.Join(args, " "), err, out) + } + + return out, nil +} + +type runGroup struct { + OkCode []int // error codes that are acceptable, other than 0, if any + Runner commandRunner // the runner that actually runs our commands + ErrAcc error // first error encountered, if any +} + +func newRunGroup(okCode []int, runner commandRunner) *runGroup { + return &runGroup{ + OkCode: okCode, + Runner: runner, + } +} + +func (rg *runGroup) okCode(err error) bool { + got := errCode(err) + for _, want := range rg.OkCode { + if got == want { + return true + } + } + return false +} + +func (rg *runGroup) Output(args ...string) []byte { + b, err := rg.Runner.output(args...) + if rg.ErrAcc == nil && err != nil && !rg.okCode(err) { + rg.ErrAcc = err + } + return b +} + +func (rg *runGroup) Run(args ...string) { + err := rg.Runner.run(args...) + if rg.ErrAcc == nil && err != nil && !rg.okCode(err) { + rg.ErrAcc = err + } +} diff --git a/wgengine/watchdog_js.go b/wgengine/watchdog_js.go index 9dcb29c4e..872ce36d5 100644 --- a/wgengine/watchdog_js.go +++ b/wgengine/watchdog_js.go @@ -1,17 +1,17 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build js - -package wgengine - -import "tailscale.com/net/dns/resolver" - -type watchdogEngine struct { - Engine - wrap Engine -} - -func (e *watchdogEngine) GetResolver() (r *resolver.Resolver, ok bool) { - return nil, false -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build js + +package wgengine + +import "tailscale.com/net/dns/resolver" + +type watchdogEngine struct { + Engine + wrap Engine +} + +func (e *watchdogEngine) GetResolver() (r *resolver.Resolver, ok bool) { + return nil, false +} diff --git a/wgengine/wgcfg/device.go b/wgengine/wgcfg/device.go index 9b83998cb..80fa159e3 100644 --- a/wgengine/wgcfg/device.go +++ b/wgengine/wgcfg/device.go @@ -1,68 +1,68 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package wgcfg - -import ( - "io" - "sort" - - "github.com/tailscale/wireguard-go/conn" - "github.com/tailscale/wireguard-go/device" - "github.com/tailscale/wireguard-go/tun" - "tailscale.com/types/logger" - "tailscale.com/util/multierr" -) - -// NewDevice returns a wireguard-go Device configured for Tailscale use. -func NewDevice(tunDev tun.Device, bind conn.Bind, logger *device.Logger) *device.Device { - ret := device.NewDevice(tunDev, bind, logger) - ret.DisableSomeRoamingForBrokenMobileSemantics() - return ret -} - -func DeviceConfig(d *device.Device) (*Config, error) { - r, w := io.Pipe() - errc := make(chan error, 1) - go func() { - errc <- d.IpcGetOperation(w) - w.Close() - }() - cfg, fromErr := FromUAPI(r) - r.Close() - getErr := <-errc - err := multierr.New(getErr, fromErr) - if err != nil { - return nil, err - } - sort.Slice(cfg.Peers, func(i, j int) bool { - return cfg.Peers[i].PublicKey.Less(cfg.Peers[j].PublicKey) - }) - return cfg, nil -} - -// ReconfigDevice replaces the existing device configuration with cfg. -func ReconfigDevice(d *device.Device, cfg *Config, logf logger.Logf) (err error) { - defer func() { - if err != nil { - logf("wgcfg.Reconfig failed: %v", err) - } - }() - - prev, err := DeviceConfig(d) - if err != nil { - return err - } - - r, w := io.Pipe() - errc := make(chan error, 1) - go func() { - errc <- d.IpcSetOperation(r) - r.Close() - }() - - toErr := cfg.ToUAPI(logf, w, prev) - w.Close() - setErr := <-errc - return multierr.New(setErr, toErr) -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package wgcfg + +import ( + "io" + "sort" + + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/device" + "github.com/tailscale/wireguard-go/tun" + "tailscale.com/types/logger" + "tailscale.com/util/multierr" +) + +// NewDevice returns a wireguard-go Device configured for Tailscale use. +func NewDevice(tunDev tun.Device, bind conn.Bind, logger *device.Logger) *device.Device { + ret := device.NewDevice(tunDev, bind, logger) + ret.DisableSomeRoamingForBrokenMobileSemantics() + return ret +} + +func DeviceConfig(d *device.Device) (*Config, error) { + r, w := io.Pipe() + errc := make(chan error, 1) + go func() { + errc <- d.IpcGetOperation(w) + w.Close() + }() + cfg, fromErr := FromUAPI(r) + r.Close() + getErr := <-errc + err := multierr.New(getErr, fromErr) + if err != nil { + return nil, err + } + sort.Slice(cfg.Peers, func(i, j int) bool { + return cfg.Peers[i].PublicKey.Less(cfg.Peers[j].PublicKey) + }) + return cfg, nil +} + +// ReconfigDevice replaces the existing device configuration with cfg. +func ReconfigDevice(d *device.Device, cfg *Config, logf logger.Logf) (err error) { + defer func() { + if err != nil { + logf("wgcfg.Reconfig failed: %v", err) + } + }() + + prev, err := DeviceConfig(d) + if err != nil { + return err + } + + r, w := io.Pipe() + errc := make(chan error, 1) + go func() { + errc <- d.IpcSetOperation(r) + r.Close() + }() + + toErr := cfg.ToUAPI(logf, w, prev) + w.Close() + setErr := <-errc + return multierr.New(setErr, toErr) +} diff --git a/wgengine/wgcfg/device_test.go b/wgengine/wgcfg/device_test.go index c54ad16d9..d54282e4b 100644 --- a/wgengine/wgcfg/device_test.go +++ b/wgengine/wgcfg/device_test.go @@ -1,261 +1,261 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package wgcfg - -import ( - "bufio" - "bytes" - "io" - "net/netip" - "os" - "sort" - "strings" - "sync" - "testing" - - "github.com/tailscale/wireguard-go/conn" - "github.com/tailscale/wireguard-go/device" - "github.com/tailscale/wireguard-go/tun" - "go4.org/mem" - "tailscale.com/types/key" -) - -func TestDeviceConfig(t *testing.T) { - newK := func() (key.NodePublic, key.NodePrivate) { - t.Helper() - k := key.NewNode() - return k.Public(), k - } - k1, pk1 := newK() - ip1 := netip.MustParsePrefix("10.0.0.1/32") - - k2, pk2 := newK() - ip2 := netip.MustParsePrefix("10.0.0.2/32") - - k3, _ := newK() - ip3 := netip.MustParsePrefix("10.0.0.3/32") - - cfg1 := &Config{ - PrivateKey: pk1, - Peers: []Peer{{ - PublicKey: k2, - AllowedIPs: []netip.Prefix{ip2}, - }}, - } - - cfg2 := &Config{ - PrivateKey: pk2, - Peers: []Peer{{ - PublicKey: k1, - AllowedIPs: []netip.Prefix{ip1}, - PersistentKeepalive: 5, - }}, - } - - device1 := NewDevice(newNilTun(), new(noopBind), device.NewLogger(device.LogLevelError, "device1")) - device2 := NewDevice(newNilTun(), new(noopBind), device.NewLogger(device.LogLevelError, "device2")) - defer device1.Close() - defer device2.Close() - - cmp := func(t *testing.T, d *device.Device, want *Config) { - t.Helper() - got, err := DeviceConfig(d) - if err != nil { - t.Fatal(err) - } - prev := new(Config) - gotbuf := new(strings.Builder) - err = got.ToUAPI(t.Logf, gotbuf, prev) - gotStr := gotbuf.String() - if err != nil { - t.Errorf("got.ToUAPI(): error: %v", err) - return - } - wantbuf := new(strings.Builder) - err = want.ToUAPI(t.Logf, wantbuf, prev) - wantStr := wantbuf.String() - if err != nil { - t.Errorf("want.ToUAPI(): error: %v", err) - return - } - if gotStr != wantStr { - buf := new(bytes.Buffer) - w := bufio.NewWriter(buf) - if err := d.IpcGetOperation(w); err != nil { - t.Errorf("on error, could not IpcGetOperation: %v", err) - } - w.Flush() - t.Errorf("config mismatch:\n---- got:\n%s\n---- want:\n%s\n---- uapi:\n%s", gotStr, wantStr, buf.String()) - } - } - - t.Run("device1 config", func(t *testing.T) { - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device1, cfg1) - }) - - t.Run("device2 config", func(t *testing.T) { - if err := ReconfigDevice(device2, cfg2, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device2, cfg2) - }) - - // This is only to test that Config and Reconfig are properly synchronized. - t.Run("device2 config/reconfig", func(t *testing.T) { - var wg sync.WaitGroup - wg.Add(2) - - go func() { - ReconfigDevice(device2, cfg2, t.Logf) - wg.Done() - }() - - go func() { - DeviceConfig(device2) - wg.Done() - }() - - wg.Wait() - }) - - t.Run("device1 modify peer", func(t *testing.T) { - cfg1.Peers[0].DiscoKey = key.DiscoPublicFromRaw32(mem.B([]byte{0: 1, 31: 0})) - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device1, cfg1) - }) - - t.Run("device1 replace endpoint", func(t *testing.T) { - cfg1.Peers[0].DiscoKey = key.DiscoPublicFromRaw32(mem.B([]byte{0: 2, 31: 0})) - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device1, cfg1) - }) - - t.Run("device1 add new peer", func(t *testing.T) { - cfg1.Peers = append(cfg1.Peers, Peer{ - PublicKey: k3, - AllowedIPs: []netip.Prefix{ip3}, - }) - sort.Slice(cfg1.Peers, func(i, j int) bool { - return cfg1.Peers[i].PublicKey.Less(cfg1.Peers[j].PublicKey) - }) - - origCfg, err := DeviceConfig(device1) - if err != nil { - t.Fatal(err) - } - - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device1, cfg1) - - newCfg, err := DeviceConfig(device1) - if err != nil { - t.Fatal(err) - } - - peer0 := func(cfg *Config) Peer { - p, ok := cfg.PeerWithKey(k2) - if !ok { - t.Helper() - t.Fatal("failed to look up peer 2") - } - return p - } - peersEqual := func(p, q Peer) bool { - return p.PublicKey == q.PublicKey && p.DiscoKey == q.DiscoKey && p.PersistentKeepalive == q.PersistentKeepalive && cidrsEqual(p.AllowedIPs, q.AllowedIPs) - } - if !peersEqual(peer0(origCfg), peer0(newCfg)) { - t.Error("reconfig modified old peer") - } - }) - - t.Run("device1 remove peer", func(t *testing.T) { - removeKey := cfg1.Peers[len(cfg1.Peers)-1].PublicKey - cfg1.Peers = cfg1.Peers[:len(cfg1.Peers)-1] - - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device1, cfg1) - - newCfg, err := DeviceConfig(device1) - if err != nil { - t.Fatal(err) - } - - _, ok := newCfg.PeerWithKey(removeKey) - if ok { - t.Error("reconfig failed to remove peer") - } - }) -} - -// TODO: replace with a loopback tunnel -type nilTun struct { - events chan tun.Event - closed chan struct{} -} - -func newNilTun() tun.Device { - return &nilTun{ - events: make(chan tun.Event), - closed: make(chan struct{}), - } -} - -func (t *nilTun) File() *os.File { return nil } -func (t *nilTun) Flush() error { return nil } -func (t *nilTun) MTU() (int, error) { return 1420, nil } -func (t *nilTun) Name() (string, error) { return "niltun", nil } -func (t *nilTun) Events() <-chan tun.Event { return t.events } - -func (t *nilTun) Read(data [][]byte, sizes []int, offset int) (int, error) { - <-t.closed - return 0, io.EOF -} - -func (t *nilTun) Write(data [][]byte, offset int) (int, error) { - <-t.closed - return 0, io.EOF -} - -func (t *nilTun) Close() error { - close(t.events) - close(t.closed) - return nil -} - -func (t *nilTun) BatchSize() int { return 1 } - -// A noopBind is a conn.Bind that does no actual binding work. -type noopBind struct{} - -func (noopBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { - return nil, 1, nil -} -func (noopBind) Close() error { return nil } -func (noopBind) SetMark(mark uint32) error { return nil } -func (noopBind) Send(b [][]byte, ep conn.Endpoint) error { return nil } -func (noopBind) ParseEndpoint(s string) (conn.Endpoint, error) { - return dummyEndpoint(s), nil -} -func (noopBind) BatchSize() int { return 1 } - -// A dummyEndpoint is a string holding the endpoint destination. -type dummyEndpoint string - -func (e dummyEndpoint) ClearSrc() {} -func (e dummyEndpoint) SrcToString() string { return "" } -func (e dummyEndpoint) DstToString() string { return string(e) } -func (e dummyEndpoint) DstToBytes() []byte { return nil } -func (e dummyEndpoint) DstIP() netip.Addr { return netip.Addr{} } -func (dummyEndpoint) SrcIP() netip.Addr { return netip.Addr{} } +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package wgcfg + +import ( + "bufio" + "bytes" + "io" + "net/netip" + "os" + "sort" + "strings" + "sync" + "testing" + + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/device" + "github.com/tailscale/wireguard-go/tun" + "go4.org/mem" + "tailscale.com/types/key" +) + +func TestDeviceConfig(t *testing.T) { + newK := func() (key.NodePublic, key.NodePrivate) { + t.Helper() + k := key.NewNode() + return k.Public(), k + } + k1, pk1 := newK() + ip1 := netip.MustParsePrefix("10.0.0.1/32") + + k2, pk2 := newK() + ip2 := netip.MustParsePrefix("10.0.0.2/32") + + k3, _ := newK() + ip3 := netip.MustParsePrefix("10.0.0.3/32") + + cfg1 := &Config{ + PrivateKey: pk1, + Peers: []Peer{{ + PublicKey: k2, + AllowedIPs: []netip.Prefix{ip2}, + }}, + } + + cfg2 := &Config{ + PrivateKey: pk2, + Peers: []Peer{{ + PublicKey: k1, + AllowedIPs: []netip.Prefix{ip1}, + PersistentKeepalive: 5, + }}, + } + + device1 := NewDevice(newNilTun(), new(noopBind), device.NewLogger(device.LogLevelError, "device1")) + device2 := NewDevice(newNilTun(), new(noopBind), device.NewLogger(device.LogLevelError, "device2")) + defer device1.Close() + defer device2.Close() + + cmp := func(t *testing.T, d *device.Device, want *Config) { + t.Helper() + got, err := DeviceConfig(d) + if err != nil { + t.Fatal(err) + } + prev := new(Config) + gotbuf := new(strings.Builder) + err = got.ToUAPI(t.Logf, gotbuf, prev) + gotStr := gotbuf.String() + if err != nil { + t.Errorf("got.ToUAPI(): error: %v", err) + return + } + wantbuf := new(strings.Builder) + err = want.ToUAPI(t.Logf, wantbuf, prev) + wantStr := wantbuf.String() + if err != nil { + t.Errorf("want.ToUAPI(): error: %v", err) + return + } + if gotStr != wantStr { + buf := new(bytes.Buffer) + w := bufio.NewWriter(buf) + if err := d.IpcGetOperation(w); err != nil { + t.Errorf("on error, could not IpcGetOperation: %v", err) + } + w.Flush() + t.Errorf("config mismatch:\n---- got:\n%s\n---- want:\n%s\n---- uapi:\n%s", gotStr, wantStr, buf.String()) + } + } + + t.Run("device1 config", func(t *testing.T) { + if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { + t.Fatal(err) + } + cmp(t, device1, cfg1) + }) + + t.Run("device2 config", func(t *testing.T) { + if err := ReconfigDevice(device2, cfg2, t.Logf); err != nil { + t.Fatal(err) + } + cmp(t, device2, cfg2) + }) + + // This is only to test that Config and Reconfig are properly synchronized. + t.Run("device2 config/reconfig", func(t *testing.T) { + var wg sync.WaitGroup + wg.Add(2) + + go func() { + ReconfigDevice(device2, cfg2, t.Logf) + wg.Done() + }() + + go func() { + DeviceConfig(device2) + wg.Done() + }() + + wg.Wait() + }) + + t.Run("device1 modify peer", func(t *testing.T) { + cfg1.Peers[0].DiscoKey = key.DiscoPublicFromRaw32(mem.B([]byte{0: 1, 31: 0})) + if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { + t.Fatal(err) + } + cmp(t, device1, cfg1) + }) + + t.Run("device1 replace endpoint", func(t *testing.T) { + cfg1.Peers[0].DiscoKey = key.DiscoPublicFromRaw32(mem.B([]byte{0: 2, 31: 0})) + if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { + t.Fatal(err) + } + cmp(t, device1, cfg1) + }) + + t.Run("device1 add new peer", func(t *testing.T) { + cfg1.Peers = append(cfg1.Peers, Peer{ + PublicKey: k3, + AllowedIPs: []netip.Prefix{ip3}, + }) + sort.Slice(cfg1.Peers, func(i, j int) bool { + return cfg1.Peers[i].PublicKey.Less(cfg1.Peers[j].PublicKey) + }) + + origCfg, err := DeviceConfig(device1) + if err != nil { + t.Fatal(err) + } + + if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { + t.Fatal(err) + } + cmp(t, device1, cfg1) + + newCfg, err := DeviceConfig(device1) + if err != nil { + t.Fatal(err) + } + + peer0 := func(cfg *Config) Peer { + p, ok := cfg.PeerWithKey(k2) + if !ok { + t.Helper() + t.Fatal("failed to look up peer 2") + } + return p + } + peersEqual := func(p, q Peer) bool { + return p.PublicKey == q.PublicKey && p.DiscoKey == q.DiscoKey && p.PersistentKeepalive == q.PersistentKeepalive && cidrsEqual(p.AllowedIPs, q.AllowedIPs) + } + if !peersEqual(peer0(origCfg), peer0(newCfg)) { + t.Error("reconfig modified old peer") + } + }) + + t.Run("device1 remove peer", func(t *testing.T) { + removeKey := cfg1.Peers[len(cfg1.Peers)-1].PublicKey + cfg1.Peers = cfg1.Peers[:len(cfg1.Peers)-1] + + if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { + t.Fatal(err) + } + cmp(t, device1, cfg1) + + newCfg, err := DeviceConfig(device1) + if err != nil { + t.Fatal(err) + } + + _, ok := newCfg.PeerWithKey(removeKey) + if ok { + t.Error("reconfig failed to remove peer") + } + }) +} + +// TODO: replace with a loopback tunnel +type nilTun struct { + events chan tun.Event + closed chan struct{} +} + +func newNilTun() tun.Device { + return &nilTun{ + events: make(chan tun.Event), + closed: make(chan struct{}), + } +} + +func (t *nilTun) File() *os.File { return nil } +func (t *nilTun) Flush() error { return nil } +func (t *nilTun) MTU() (int, error) { return 1420, nil } +func (t *nilTun) Name() (string, error) { return "niltun", nil } +func (t *nilTun) Events() <-chan tun.Event { return t.events } + +func (t *nilTun) Read(data [][]byte, sizes []int, offset int) (int, error) { + <-t.closed + return 0, io.EOF +} + +func (t *nilTun) Write(data [][]byte, offset int) (int, error) { + <-t.closed + return 0, io.EOF +} + +func (t *nilTun) Close() error { + close(t.events) + close(t.closed) + return nil +} + +func (t *nilTun) BatchSize() int { return 1 } + +// A noopBind is a conn.Bind that does no actual binding work. +type noopBind struct{} + +func (noopBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { + return nil, 1, nil +} +func (noopBind) Close() error { return nil } +func (noopBind) SetMark(mark uint32) error { return nil } +func (noopBind) Send(b [][]byte, ep conn.Endpoint) error { return nil } +func (noopBind) ParseEndpoint(s string) (conn.Endpoint, error) { + return dummyEndpoint(s), nil +} +func (noopBind) BatchSize() int { return 1 } + +// A dummyEndpoint is a string holding the endpoint destination. +type dummyEndpoint string + +func (e dummyEndpoint) ClearSrc() {} +func (e dummyEndpoint) SrcToString() string { return "" } +func (e dummyEndpoint) DstToString() string { return string(e) } +func (e dummyEndpoint) DstToBytes() []byte { return nil } +func (e dummyEndpoint) DstIP() netip.Addr { return netip.Addr{} } +func (dummyEndpoint) SrcIP() netip.Addr { return netip.Addr{} } diff --git a/wgengine/wgcfg/parser.go b/wgengine/wgcfg/parser.go index 553aaecbb..ec3d008f7 100644 --- a/wgengine/wgcfg/parser.go +++ b/wgengine/wgcfg/parser.go @@ -1,186 +1,186 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package wgcfg - -import ( - "bufio" - "fmt" - "io" - "net" - "net/netip" - "strconv" - "strings" - - "go4.org/mem" - "tailscale.com/types/key" -) - -type ParseError struct { - why string - offender string -} - -func (e *ParseError) Error() string { - return fmt.Sprintf("%s: %q", e.why, e.offender) -} - -func parseEndpoint(s string) (host string, port uint16, err error) { - i := strings.LastIndexByte(s, ':') - if i < 0 { - return "", 0, &ParseError{"Missing port from endpoint", s} - } - host, portStr := s[:i], s[i+1:] - if len(host) < 1 { - return "", 0, &ParseError{"Invalid endpoint host", host} - } - uport, err := strconv.ParseUint(portStr, 10, 16) - if err != nil { - return "", 0, err - } - hostColon := strings.IndexByte(host, ':') - if host[0] == '[' || host[len(host)-1] == ']' || hostColon > 0 { - err := &ParseError{"Brackets must contain an IPv6 address", host} - if len(host) > 3 && host[0] == '[' && host[len(host)-1] == ']' && hostColon > 0 { - maybeV6 := net.ParseIP(host[1 : len(host)-1]) - if maybeV6 == nil || len(maybeV6) != net.IPv6len { - return "", 0, err - } - } else { - return "", 0, err - } - host = host[1 : len(host)-1] - } - return host, uint16(uport), nil -} - -// memROCut separates a mem.RO at the separator if it exists, otherwise -// it returns two empty ROs and reports that it was not found. -func memROCut(s mem.RO, sep byte) (before, after mem.RO, found bool) { - if i := mem.IndexByte(s, sep); i >= 0 { - return s.SliceTo(i), s.SliceFrom(i + 1), true - } - found = false - return -} - -// FromUAPI generates a Config from r. -// r should be generated by calling device.IpcGetOperation; -// it is not compatible with other uapi streams. -func FromUAPI(r io.Reader) (*Config, error) { - cfg := new(Config) - var peer *Peer // current peer being operated on - deviceConfig := true - - scanner := bufio.NewScanner(r) - for scanner.Scan() { - line := mem.B(scanner.Bytes()) - if line.Len() == 0 { - continue - } - key, value, ok := memROCut(line, '=') - if !ok { - return nil, fmt.Errorf("failed to cut line %q on =", line.StringCopy()) - } - valueBytes := scanner.Bytes()[key.Len()+1:] - - if key.EqualString("public_key") { - if deviceConfig { - deviceConfig = false - } - // Load/create the peer we are now configuring. - var err error - peer, err = cfg.handlePublicKeyLine(valueBytes) - if err != nil { - return nil, err - } - continue - } - - var err error - if deviceConfig { - err = cfg.handleDeviceLine(key, value, valueBytes) - } else { - err = cfg.handlePeerLine(peer, key, value, valueBytes) - } - if err != nil { - return nil, err - } - } - - if err := scanner.Err(); err != nil { - return nil, err - } - - return cfg, nil -} - -func (cfg *Config) handleDeviceLine(k, value mem.RO, valueBytes []byte) error { - switch { - case k.EqualString("private_key"): - // wireguard-go guarantees not to send zero value; private keys are already clamped. - var err error - cfg.PrivateKey, err = key.ParseNodePrivateUntyped(value) - if err != nil { - return err - } - case k.EqualString("listen_port") || k.EqualString("fwmark"): - // ignore - default: - return fmt.Errorf("unexpected IpcGetOperation key: %q", k.StringCopy()) - } - return nil -} - -func (cfg *Config) handlePublicKeyLine(valueBytes []byte) (*Peer, error) { - p := Peer{} - var err error - p.PublicKey, err = key.ParseNodePublicUntyped(mem.B(valueBytes)) - if err != nil { - return nil, err - } - cfg.Peers = append(cfg.Peers, p) - return &cfg.Peers[len(cfg.Peers)-1], nil -} - -func (cfg *Config) handlePeerLine(peer *Peer, k, value mem.RO, valueBytes []byte) error { - switch { - case k.EqualString("endpoint"): - nk, err := key.ParseNodePublicUntyped(value) - if err != nil { - return fmt.Errorf("invalid endpoint %q for peer %q, expected a hex public key", value.StringCopy(), peer.PublicKey.ShortString()) - } - // nk ought to equal peer.PublicKey. - // Under some rare circumstances, it might not. See corp issue #3016. - // Even if that happens, don't stop early, so that we can recover from it. - // Instead, note the value of nk so we can fix as needed. - peer.WGEndpoint = nk - case k.EqualString("persistent_keepalive_interval"): - n, err := mem.ParseUint(value, 10, 16) - if err != nil { - return err - } - peer.PersistentKeepalive = uint16(n) - case k.EqualString("allowed_ip"): - ipp := netip.Prefix{} - err := ipp.UnmarshalText(valueBytes) - if err != nil { - return err - } - peer.AllowedIPs = append(peer.AllowedIPs, ipp) - case k.EqualString("protocol_version"): - if !value.EqualString("1") { - return fmt.Errorf("invalid protocol version: %q", value.StringCopy()) - } - case k.EqualString("replace_allowed_ips") || - k.EqualString("preshared_key") || - k.EqualString("last_handshake_time_sec") || - k.EqualString("last_handshake_time_nsec") || - k.EqualString("tx_bytes") || - k.EqualString("rx_bytes"): - // ignore - default: - return fmt.Errorf("unexpected IpcGetOperation key: %q", k.StringCopy()) - } - return nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package wgcfg + +import ( + "bufio" + "fmt" + "io" + "net" + "net/netip" + "strconv" + "strings" + + "go4.org/mem" + "tailscale.com/types/key" +) + +type ParseError struct { + why string + offender string +} + +func (e *ParseError) Error() string { + return fmt.Sprintf("%s: %q", e.why, e.offender) +} + +func parseEndpoint(s string) (host string, port uint16, err error) { + i := strings.LastIndexByte(s, ':') + if i < 0 { + return "", 0, &ParseError{"Missing port from endpoint", s} + } + host, portStr := s[:i], s[i+1:] + if len(host) < 1 { + return "", 0, &ParseError{"Invalid endpoint host", host} + } + uport, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return "", 0, err + } + hostColon := strings.IndexByte(host, ':') + if host[0] == '[' || host[len(host)-1] == ']' || hostColon > 0 { + err := &ParseError{"Brackets must contain an IPv6 address", host} + if len(host) > 3 && host[0] == '[' && host[len(host)-1] == ']' && hostColon > 0 { + maybeV6 := net.ParseIP(host[1 : len(host)-1]) + if maybeV6 == nil || len(maybeV6) != net.IPv6len { + return "", 0, err + } + } else { + return "", 0, err + } + host = host[1 : len(host)-1] + } + return host, uint16(uport), nil +} + +// memROCut separates a mem.RO at the separator if it exists, otherwise +// it returns two empty ROs and reports that it was not found. +func memROCut(s mem.RO, sep byte) (before, after mem.RO, found bool) { + if i := mem.IndexByte(s, sep); i >= 0 { + return s.SliceTo(i), s.SliceFrom(i + 1), true + } + found = false + return +} + +// FromUAPI generates a Config from r. +// r should be generated by calling device.IpcGetOperation; +// it is not compatible with other uapi streams. +func FromUAPI(r io.Reader) (*Config, error) { + cfg := new(Config) + var peer *Peer // current peer being operated on + deviceConfig := true + + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := mem.B(scanner.Bytes()) + if line.Len() == 0 { + continue + } + key, value, ok := memROCut(line, '=') + if !ok { + return nil, fmt.Errorf("failed to cut line %q on =", line.StringCopy()) + } + valueBytes := scanner.Bytes()[key.Len()+1:] + + if key.EqualString("public_key") { + if deviceConfig { + deviceConfig = false + } + // Load/create the peer we are now configuring. + var err error + peer, err = cfg.handlePublicKeyLine(valueBytes) + if err != nil { + return nil, err + } + continue + } + + var err error + if deviceConfig { + err = cfg.handleDeviceLine(key, value, valueBytes) + } else { + err = cfg.handlePeerLine(peer, key, value, valueBytes) + } + if err != nil { + return nil, err + } + } + + if err := scanner.Err(); err != nil { + return nil, err + } + + return cfg, nil +} + +func (cfg *Config) handleDeviceLine(k, value mem.RO, valueBytes []byte) error { + switch { + case k.EqualString("private_key"): + // wireguard-go guarantees not to send zero value; private keys are already clamped. + var err error + cfg.PrivateKey, err = key.ParseNodePrivateUntyped(value) + if err != nil { + return err + } + case k.EqualString("listen_port") || k.EqualString("fwmark"): + // ignore + default: + return fmt.Errorf("unexpected IpcGetOperation key: %q", k.StringCopy()) + } + return nil +} + +func (cfg *Config) handlePublicKeyLine(valueBytes []byte) (*Peer, error) { + p := Peer{} + var err error + p.PublicKey, err = key.ParseNodePublicUntyped(mem.B(valueBytes)) + if err != nil { + return nil, err + } + cfg.Peers = append(cfg.Peers, p) + return &cfg.Peers[len(cfg.Peers)-1], nil +} + +func (cfg *Config) handlePeerLine(peer *Peer, k, value mem.RO, valueBytes []byte) error { + switch { + case k.EqualString("endpoint"): + nk, err := key.ParseNodePublicUntyped(value) + if err != nil { + return fmt.Errorf("invalid endpoint %q for peer %q, expected a hex public key", value.StringCopy(), peer.PublicKey.ShortString()) + } + // nk ought to equal peer.PublicKey. + // Under some rare circumstances, it might not. See corp issue #3016. + // Even if that happens, don't stop early, so that we can recover from it. + // Instead, note the value of nk so we can fix as needed. + peer.WGEndpoint = nk + case k.EqualString("persistent_keepalive_interval"): + n, err := mem.ParseUint(value, 10, 16) + if err != nil { + return err + } + peer.PersistentKeepalive = uint16(n) + case k.EqualString("allowed_ip"): + ipp := netip.Prefix{} + err := ipp.UnmarshalText(valueBytes) + if err != nil { + return err + } + peer.AllowedIPs = append(peer.AllowedIPs, ipp) + case k.EqualString("protocol_version"): + if !value.EqualString("1") { + return fmt.Errorf("invalid protocol version: %q", value.StringCopy()) + } + case k.EqualString("replace_allowed_ips") || + k.EqualString("preshared_key") || + k.EqualString("last_handshake_time_sec") || + k.EqualString("last_handshake_time_nsec") || + k.EqualString("tx_bytes") || + k.EqualString("rx_bytes"): + // ignore + default: + return fmt.Errorf("unexpected IpcGetOperation key: %q", k.StringCopy()) + } + return nil +} diff --git a/wgengine/winnet/winnet_windows.go b/wgengine/winnet/winnet_windows.go index 01e38517d..283ce5ad1 100644 --- a/wgengine/winnet/winnet_windows.go +++ b/wgengine/winnet/winnet_windows.go @@ -1,26 +1,26 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package winnet - -import ( - "fmt" - "syscall" - "unsafe" - - "github.com/go-ole/go-ole" -) - -func (v *INetworkConnection) GetAdapterId() (string, error) { - buf := ole.GUID{} - hr, _, _ := syscall.Syscall( - v.VTable().GetAdapterId, - 2, - uintptr(unsafe.Pointer(v)), - uintptr(unsafe.Pointer(&buf)), - 0) - if hr != 0 { - return "", fmt.Errorf("GetAdapterId failed: %08x", hr) - } - return buf.String(), nil -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package winnet + +import ( + "fmt" + "syscall" + "unsafe" + + "github.com/go-ole/go-ole" +) + +func (v *INetworkConnection) GetAdapterId() (string, error) { + buf := ole.GUID{} + hr, _, _ := syscall.Syscall( + v.VTable().GetAdapterId, + 2, + uintptr(unsafe.Pointer(v)), + uintptr(unsafe.Pointer(&buf)), + 0) + if hr != 0 { + return "", fmt.Errorf("GetAdapterId failed: %08x", hr) + } + return buf.String(), nil +} diff --git a/words/words.go b/words/words.go index 18efb75d7..b373ffef6 100644 --- a/words/words.go +++ b/words/words.go @@ -1,58 +1,58 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package words contains accessors for some nice words. -package words - -import ( - "bytes" - _ "embed" - "strings" - "sync" -) - -//go:embed tails.txt -var tailsTxt []byte - -//go:embed scales.txt -var scalesTxt []byte - -var ( - once sync.Once - tails, scales []string -) - -// Tails returns words about tails. -func Tails() []string { - once.Do(initWords) - return tails -} - -// Scales returns words about scales. -func Scales() []string { - once.Do(initWords) - return scales -} - -func initWords() { - tails = parseWords(tailsTxt) - scales = parseWords(scalesTxt) -} - -func parseWords(txt []byte) []string { - n := bytes.Count(txt, []byte{'\n'}) - ret := make([]string, 0, n) - for len(txt) > 0 { - word := txt - i := bytes.IndexByte(txt, '\n') - if i != -1 { - word, txt = word[:i], txt[i+1:] - } else { - txt = nil - } - if word := strings.TrimSpace(string(word)); word != "" && word[0] != '#' { - ret = append(ret, word) - } - } - return ret -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package words contains accessors for some nice words. +package words + +import ( + "bytes" + _ "embed" + "strings" + "sync" +) + +//go:embed tails.txt +var tailsTxt []byte + +//go:embed scales.txt +var scalesTxt []byte + +var ( + once sync.Once + tails, scales []string +) + +// Tails returns words about tails. +func Tails() []string { + once.Do(initWords) + return tails +} + +// Scales returns words about scales. +func Scales() []string { + once.Do(initWords) + return scales +} + +func initWords() { + tails = parseWords(tailsTxt) + scales = parseWords(scalesTxt) +} + +func parseWords(txt []byte) []string { + n := bytes.Count(txt, []byte{'\n'}) + ret := make([]string, 0, n) + for len(txt) > 0 { + word := txt + i := bytes.IndexByte(txt, '\n') + if i != -1 { + word, txt = word[:i], txt[i+1:] + } else { + txt = nil + } + if word := strings.TrimSpace(string(word)); word != "" && word[0] != '#' { + ret = append(ret, word) + } + } + return ret +} diff --git a/words/words_test.go b/words/words_test.go index e96c234d7..a9691792a 100644 --- a/words/words_test.go +++ b/words/words_test.go @@ -1,38 +1,38 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package words - -import ( - "strings" - "testing" -) - -func TestWords(t *testing.T) { - test := func(t *testing.T, words []string) { - t.Helper() - if len(words) == 0 { - t.Error("no words") - } - seen := map[string]bool{} - for _, w := range words { - if seen[w] { - t.Errorf("dup word %q", w) - } - seen[w] = true - if w == "" || strings.IndexFunc(w, nonASCIILower) != -1 { - t.Errorf("malformed word %q", w) - } - } - } - t.Run("tails", func(t *testing.T) { test(t, Tails()) }) - t.Run("scales", func(t *testing.T) { test(t, Scales()) }) - t.Logf("%v tails * %v scales = %v beautiful combinations", len(Tails()), len(Scales()), len(Tails())*len(Scales())) -} - -func nonASCIILower(r rune) bool { - if 'a' <= r && r <= 'z' { - return false - } - return true -} +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package words + +import ( + "strings" + "testing" +) + +func TestWords(t *testing.T) { + test := func(t *testing.T, words []string) { + t.Helper() + if len(words) == 0 { + t.Error("no words") + } + seen := map[string]bool{} + for _, w := range words { + if seen[w] { + t.Errorf("dup word %q", w) + } + seen[w] = true + if w == "" || strings.IndexFunc(w, nonASCIILower) != -1 { + t.Errorf("malformed word %q", w) + } + } + } + t.Run("tails", func(t *testing.T) { test(t, Tails()) }) + t.Run("scales", func(t *testing.T) { test(t, Scales()) }) + t.Logf("%v tails * %v scales = %v beautiful combinations", len(Tails()), len(Scales()), len(Tails())*len(Scales())) +} + +func nonASCIILower(r rune) bool { + if 'a' <= r && r <= 'z' { + return false + } + return true +}